Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: add basic auth support #1

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
15 changes: 12 additions & 3 deletions cmd/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"net/http"

"github.com/spf13/cobra"

Expand All @@ -15,9 +14,19 @@ func lock(group, id, url *string) *cobra.Command {
Use: "recursive-lock",
Short: "Try to reserve (lock) a slot for rebooting",
RunE: func(cmd *cobra.Command, args []string) error {
httpClient := http.DefaultClient
if id == nil {
var err error
id, err = machineID()
if err != nil {
return fmt.Errorf("getting machine ID: %w", err)
}
}

c, err := client.New(*url, *group, *id, httpClient)
c, err := client.New(&client.Config{
ID: *id,
Group: *group,
URL: *url,
})
if err != nil {
return fmt.Errorf("building the client: %w", err)
}
Expand Down
16 changes: 16 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
package cmd

import (
"fmt"
"io/ioutil"

"github.com/spf13/cobra"
)

Expand All @@ -20,3 +23,16 @@ func Command() *cobra.Command {

return cli
}

// machineID is a helper to return unique ID
// of the machine.
func machineID() (*string, error) {
id, err := ioutil.ReadFile("/etc/machine-id")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this path configurable would make the more testable, but I guess it would have to be done via e.g. --id-from-file flag or something. Also, we don't have any tests in place for this, so would be more effort.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get your point, however we should not forget that machineID() is a helper to provide a default value for the --id flag. In case one wants to use a different ID he can just use --id my-id or even --id $(cat /tmp/my-id-file).

If we really want to test this helper, we can still rely on fs abstraction with afero.FS but it becomes a bit overkill IMHO. :)

if err != nil {
return nil, fmt.Errorf("reading machine ID from file: %w", err)
}

sid := string(id)

return &sid, nil
}
15 changes: 12 additions & 3 deletions cmd/unlock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"net/http"

"github.com/spf13/cobra"

Expand All @@ -15,9 +14,19 @@ func unlock(group, id, url *string) *cobra.Command {
Use: "unlock-if-held",
Short: "Try to release (unlock) a slot that it was previously holding",
RunE: func(cmd *cobra.Command, args []string) error {
httpClient := http.DefaultClient
if id == nil {
var err error
id, err = machineID()
if err != nil {
return fmt.Errorf("getting machine ID: %w", err)
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is now duplicate. Can we refactor?


c, err := client.New(*url, *group, *id, httpClient)
c, err := client.New(&client.Config{
ID: *id,
Group: *group,
URL: *url,
})
if err != nil {
return fmt.Errorf("building the client: %w", err)
}
Expand Down
50 changes: 50 additions & 0 deletions pkg/client/authentication.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package client

import (
"context"
"fmt"
"net/http"
)

type basicAuthRoundTripper struct {
username string
password string
rt http.RoundTripper
}

// RoundTrip is required to implement RoundTripper interface.
// We check if an authorization token is already set, if not we set it.
// We return the initial RoundTripper to chain it in the whole process.
func (b *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) != 0 {
resp, err := b.rt.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("inner round trip error: %w", err)
}

return resp, nil
}

req = req.Clone(context.TODO())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
req = req.Clone(context.TODO())
req = req.Clone(req.Context())

Otherwise using round tripper swallows the context control on the request:

package client_test

import (
  "context"
  "errors"
  "net/http"
  "testing"
  "time"

  "github.com/flatcar-linux/fleetlock/pkg/client"
)

func Test_Cancelling_context_for_request_performed_with_http_client_with_basic_auth_round_tripper_cancels_the_request(t *testing.T) {
  httpClient := http.Client{
    Transport: client.NewBasicAuthRoundTripper("foo", "bar", nil),
  }

  requestTimeout := time.Second

  ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
  t.Cleanup(cancel)

  req, err := http.NewRequestWithContext(ctx, "GET", "http://10.255.255.1", nil)
  if err != nil {
    t.Fatal(err)
  }

  errCh := make(chan error, 1)
  go func() {
    _, err = httpClient.Do(req)
    errCh <- err
  }()

  testDeadline := time.NewTimer(2 * requestTimeout)
  select {
  case <-testDeadline.C:
    t.Fatalf("Expected request to return before the deadline")
  case err := <-errCh:
    if err != nil && !errors.Is(err, context.DeadlineExceeded) {
      t.Fatal(err)
    }
  }
}

req.SetBasicAuth(b.username, b.password)

resp, err := b.rt.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("inner round trip error: %w", err)
}

return resp, nil
}

// NewBasicAuthRoundTripper returns a basicAuthRoundTripper with username and password.
func NewBasicAuthRoundTripper(username, password string, rt http.RoundTripper) http.RoundTripper {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can let rt to be nil and use net/http.Transport{} as default?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like round tripper should have it's own set of tests, separate from the client, to make the effort required for testing it smaller.

if rt == nil {
rt = &http.Transport{}
}

return &basicAuthRoundTripper{
username: username,
password: password,
rt: rt,
}
}
46 changes: 32 additions & 14 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,6 @@ type Client struct {
http HTTPClient
}

// New builds a FleetLock client.
func New(baseServerURL, group, id string, c HTTPClient) (*Client, error) {
if _, err := url.ParseRequestURI(baseServerURL); err != nil {
return nil, fmt.Errorf("parsing URL: %w", err)
}

return &Client{
baseServerURL: baseServerURL,
http: c,
group: group,
id: id,
}, nil
}

// RecursiveLock tries to reserve (lock) a slot for rebooting.
func (c *Client) RecursiveLock(ctx context.Context) error {
req, err := c.generateRequest(ctx, "v1/pre-reboot")
Expand Down Expand Up @@ -148,3 +134,35 @@ func handleResponse(resp *http.Response) error {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
}

// New builds a FleetLock client.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we move this function to the bottom of the file? I think right after exported struct definition it's very good, as this what potential reader will be looking for initially.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree - It might be related to a rebase onto main.

func New(cfg *Config) (*Client, error) {
fleetlock := &Client{
baseServerURL: cfg.URL,
http: cfg.HTTP,
group: cfg.Group,
id: cfg.ID,
}

if fleetlock.id == "" {
return nil, fmt.Errorf("ID is required")
}

if fleetlock.baseServerURL == "" {
return nil, fmt.Errorf("URL is required")
}

if _, err := url.ParseRequestURI(fleetlock.baseServerURL); err != nil {
return nil, fmt.Errorf("parsing URL: %w", err)
}

if fleetlock.group == "" {
fleetlock.group = "default"
}

if fleetlock.http == nil {
fleetlock.http = http.DefaultClient
}

return fleetlock, nil
}
Loading