diff --git a/cmd/lock.go b/cmd/lock.go index 5b920b8..b2164f2 100644 --- a/cmd/lock.go +++ b/cmd/lock.go @@ -3,7 +3,6 @@ package cmd import ( "context" "fmt" - "net/http" "github.com/spf13/cobra" @@ -15,9 +14,15 @@ func lock(group, id, url *string) *cobra.Command { Use: "recursive-lock", Short: "Try to reserve (lock) a slot for rebooting", RunE: func(cmd *cobra.Command, args []string) error { - httpClient := http.DefaultClient + if err := checkID(id); err != nil { + return fmt.Errorf("checking ID: %w", err) + } - c, err := client.New(*url, *group, *id, httpClient) + c, err := client.New(&client.Config{ + ID: *id, + Group: *group, + URL: *url, + }) if err != nil { return fmt.Errorf("building the client: %w", err) } diff --git a/cmd/root.go b/cmd/root.go index 37b9ec7..56c4cb2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,9 @@ package cmd import ( + "fmt" + "io/ioutil" + "github.com/spf13/cobra" ) @@ -20,3 +23,32 @@ func Command() *cobra.Command { return cli } + +// machineID is a helper to return unique ID +// of the machine. +func machineID() (string, error) { + id, err := ioutil.ReadFile("/etc/machine-id") + if err != nil { + return "", fmt.Errorf("reading machine ID from file: %w", err) + } + + return string(id), nil +} + +// checkID asserts that the ID is not nil, if it's the case +// it uses `machineID` to generate a default one. +func checkID(id *string) error { + // the ID is set and it's not empty. + if id != nil && *id != "" { + return nil + } + + i, err := machineID() + if err != nil { + return fmt.Errorf("getting default machine ID: %w", err) + } + + *id = i + + return nil +} diff --git a/cmd/unlock.go b/cmd/unlock.go index 7e82ef9..d13cd15 100644 --- a/cmd/unlock.go +++ b/cmd/unlock.go @@ -3,7 +3,6 @@ package cmd import ( "context" "fmt" - "net/http" "github.com/spf13/cobra" @@ -15,9 +14,15 @@ func unlock(group, id, url *string) *cobra.Command { Use: "unlock-if-held", Short: "Try to release (unlock) a slot that it was previously holding", RunE: func(cmd *cobra.Command, args []string) error { - httpClient := http.DefaultClient + if err := checkID(id); err != nil { + return fmt.Errorf("checking ID: %w", err) + } - c, err := client.New(*url, *group, *id, httpClient) + c, err := client.New(&client.Config{ + ID: *id, + Group: *group, + URL: *url, + }) if err != nil { return fmt.Errorf("building the client: %w", err) } diff --git a/pkg/client/authentication.go b/pkg/client/authentication.go new file mode 100644 index 0000000..a8c5a58 --- /dev/null +++ b/pkg/client/authentication.go @@ -0,0 +1,50 @@ +package client + +import ( + "context" + "fmt" + "net/http" +) + +type basicAuthRoundTripper struct { + username string + password string + rt http.RoundTripper +} + +// RoundTrip is required to implement RoundTripper interface. +// We check if an authorization token is already set, if not we set it. +// We return the initial RoundTripper to chain it in the whole process. +func (b *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if len(req.Header.Get("Authorization")) != 0 { + resp, err := b.rt.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("inner round trip error: %w", err) + } + + return resp, nil + } + + req = req.Clone(context.TODO()) + req.SetBasicAuth(b.username, b.password) + + resp, err := b.rt.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("inner round trip error: %w", err) + } + + return resp, nil +} + +// NewBasicAuthRoundTripper returns a basicAuthRoundTripper with username and password. +func NewBasicAuthRoundTripper(username, password string, rt http.RoundTripper) http.RoundTripper { + if rt == nil { + rt = &http.Transport{} + } + + return &basicAuthRoundTripper{ + username: username, + password: password, + rt: rt, + } +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 3b35b6e..cfd7833 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -48,17 +48,35 @@ type Client struct { } // New builds a FleetLock client. -func New(baseServerURL, group, id string, c HTTPClient) (*Client, error) { - if _, err := url.ParseRequestURI(baseServerURL); err != nil { +func New(cfg *Config) (*Client, error) { + fleetlock := &Client{ + baseServerURL: cfg.URL, + http: cfg.HTTP, + group: cfg.Group, + id: cfg.ID, + } + + if fleetlock.id == "" { + return nil, fmt.Errorf("ID is required") + } + + if fleetlock.baseServerURL == "" { + return nil, fmt.Errorf("URL is required") + } + + if _, err := url.ParseRequestURI(fleetlock.baseServerURL); err != nil { return nil, fmt.Errorf("parsing URL: %w", err) } - return &Client{ - baseServerURL: baseServerURL, - http: c, - group: group, - id: id, - }, nil + if fleetlock.group == "" { + fleetlock.group = "default" + } + + if fleetlock.http == nil { + fleetlock.http = http.DefaultClient + } + + return fleetlock, nil } // RecursiveLock tries to reserve (lock) a slot for rebooting. diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index ede89cc..259d631 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -3,15 +3,21 @@ package client_test import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io/ioutil" "net/http" + "reflect" "testing" "github.com/flatcar-linux/fleetlock/pkg/client" ) +var fleetlockHeaders = http.Header{ + "Fleet-Lock-Protocol": []string{"true"}, +} + type httpClient struct { do func(req *http.Request) (*http.Response, error) r *http.Request @@ -23,10 +29,16 @@ func (m *httpClient) Do(req *http.Request) (*http.Response, error) { return m.do(req) } +func (m *httpClient) RoundTrip(req *http.Request) (*http.Response, error) { + m.r = req + + return m.do(req) +} + func TestBadURL(t *testing.T) { t.Parallel() - _, err := client.New("this is not an URL", "default", "1234", nil) + _, err := client.New(&client.Config{URL: "this is not an URL", ID: "1234"}) if err == nil { t.Fatalf("should get an error") } @@ -48,32 +60,70 @@ func TestClient(t *testing.T) { expErr error body []byte doErr error + cfg *client.Config + expCfg *client.Config }{ { statusCode: 200, + cfg: &client.Config{ + ID: "1234", + URL: "http://1.2.3.4", + }, + expCfg: &client.Config{ + Group: "default", + }, }, { statusCode: 500, expErr: errors.New("fleetlock error: this is an error (error_kind)"), body: []byte(`{"kind":"error_kind","value":"this is an error"}`), + cfg: &client.Config{ + ID: "1234", + URL: "http://1.2.3.4", + }, + expCfg: &client.Config{ + Group: "default", + }, }, { statusCode: 500, expErr: errors.New("unmarshalling error: invalid character '\"' after object key:value pair"), body: []byte(`{"kind":"error_kind" "value":"this is an error"}`), + cfg: &client.Config{ + ID: "1234", + URL: "http://1.2.3.4", + Group: "lokomotive", + }, + expCfg: &client.Config{ + Group: "lokomotive", + }, }, { statusCode: 100, expErr: errors.New("unexpected status code: 100"), + cfg: &client.Config{ + ID: "1234", + URL: "http://1.2.3.4", + }, + expCfg: &client.Config{ + Group: "default", + }, }, { expErr: errors.New("doing the request: connection refused"), doErr: errors.New("connection refused"), + cfg: &client.Config{ + ID: "1234", + URL: "http://1.2.3.4", + }, + expCfg: &client.Config{ + Group: "default", + }, }, } { test := test - newClient := func(statusCode int, body []byte, doErr error) (*httpClient, *client.Client) { + newClient := func(cfg *client.Config, statusCode int, body []byte, doErr error) (*httpClient, *client.Client) { h := &httpClient{ do: func(req *http.Request) (*http.Response, error) { return &http.Response{ @@ -83,7 +133,9 @@ func TestClient(t *testing.T) { }, } - c, err := client.New("http://1.2.3.4", "default", "1234", h) + cfg.HTTP = h + + c, err := client.New(cfg) if err != nil { t.Fatalf("Unexpected error creating client: %v", err) } @@ -91,10 +143,29 @@ func TestClient(t *testing.T) { return h, c } + getPayload := func(h *httpClient) *client.Payload { + b, err := h.r.GetBody() + if err != nil { + t.Fatalf("unable to get body from request: %v", err) + } + + payload, err := ioutil.ReadAll(b) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + + var p client.Payload + if err := json.Unmarshal(payload, &p); err != nil { + t.Fatalf("unable to unmarshal payload: %v", err) + } + + return &p + } + t.Run(fmt.Sprintf("UnlockIfHeld_%d", test.statusCode), func(t *testing.T) { t.Parallel() - h, c := newClient(test.statusCode, test.body, test.doErr) + h, c := newClient(test.cfg, test.statusCode, test.body, test.doErr) err := c.UnlockIfHeld(ctx) if err != nil && err.Error() != test.expErr.Error() { @@ -106,12 +177,22 @@ func TestClient(t *testing.T) { if h.r.URL.String() != expURL { t.Fatalf("should have %s for URL, got: %s", expURL, h.r.URL.String()) } + + if !reflect.DeepEqual(h.r.Header, fleetlockHeaders) { + t.Fatalf("should have %v for headers, got: %s", fleetlockHeaders, h.r.Header) + } + + payload := getPayload(h) + + if payload.ClientParams.Group != test.expCfg.Group { + t.Fatalf("payload's group should be : %s, got: %s", test.expCfg.Group, payload.ClientParams.Group) + } }) t.Run(fmt.Sprintf("RecursiveLock_%d", test.statusCode), func(t *testing.T) { t.Parallel() - h, c := newClient(test.statusCode, test.body, test.doErr) + h, c := newClient(test.cfg, test.statusCode, test.body, test.doErr) err := c.RecursiveLock(ctx) if err != nil && err.Error() != test.expErr.Error() { @@ -123,6 +204,16 @@ func TestClient(t *testing.T) { if h.r.URL.String() != expURL { t.Fatalf("should have %s for URL, got: %s", expURL, h.r.URL.String()) } + + if !reflect.DeepEqual(h.r.Header, fleetlockHeaders) { + t.Fatalf("should have %v for headers, got: %s", fleetlockHeaders, h.r.Header) + } + + payload := getPayload(h) + + if payload.ClientParams.Group != test.expCfg.Group { + t.Fatalf("payload's group should be : %s, got: %s", test.expCfg.Group, payload.ClientParams.Group) + } }) } } @@ -145,7 +236,11 @@ func Test_Client_use_given_context_for_requests(t *testing.T) { }, } - c, err := client.New("http://1.2.3.4", "default", "1234", h) + c, err := client.New(&client.Config{ + URL: "http://1.2.3.4", + ID: "1234", + HTTP: h, + }) if err != nil { t.Fatalf("Unexpected error creating client: %v", err) } @@ -160,3 +255,72 @@ func Test_Client_use_given_context_for_requests(t *testing.T) { t.Fatalf("Unexpected error while unlocking: %v", err) } } + +func TestBasicAuth(t *testing.T) { + t.Parallel() + + var ( + username = "flatcar" + password = "p4ssw0rd" + ) + + ctx := context.Background() + + tr := &httpClient{ + do: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + }, nil + }, + } + + h := http.Client{ + Transport: client.NewBasicAuthRoundTripper(username, password, tr), + } + + c, err := client.New(&client.Config{ID: "1234", HTTP: &h, URL: "http://1.2.3.4"}) + if err != nil { + t.Fatalf("Unexpected error creating client: %v", err) + } + + err = c.RecursiveLock(ctx) + if err != nil { + t.Fatalf("should have nil for err, got: %v", err) + } + + u, p, ok := tr.r.BasicAuth() + if u != username || p != password || !ok { + t.Fatalf("basic auth creds do not match") + } +} + +func TestRequiredParameters(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + cfg *client.Config + err error + }{ + { + cfg: &client.Config{ + URL: "http://1.2.3.4", + }, + err: errors.New("ID is required"), + }, + { + cfg: &client.Config{ + ID: "1234", + }, + err: errors.New("URL is required"), + }, + } { + _, err := client.New(test.cfg) + if err == nil { + t.Fatal("error should not be nil") + } + + if err.Error() != test.err.Error() { + t.Fatalf("error should be: %v, got: %v", test.err, err) + } + } +} diff --git a/pkg/client/config.go b/pkg/client/config.go new file mode 100644 index 0000000..6fff6ec --- /dev/null +++ b/pkg/client/config.go @@ -0,0 +1,16 @@ +package client + +// Config is the dedicated fleetlock client config. +type Config struct { + // Group of the instance. Defaults to "default" + Group string + // ID of the instance, must be unique and should persist across reboot. + // Required. + ID string + // HTTP client to use - can be used to implement authentication logic. + // Defaults to `http.DefaultClient` + HTTP HTTPClient + // URL of the FleetLock server implementation. + // Required. + URL string +}