From 17a73efaeec1d55893c8970dad1045166581f8d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Tue, 5 Nov 2024 13:15:28 -0300 Subject: [PATCH 1/6] Inject logger using context --- cmd/api/api.go | 21 +- cmd/configread/configread.go | 3 +- cmd/configwrite/configwrite.go | 11 +- cmd/configwrite/configwrite_test.go | 19 +- cmd/fileexperts/fileexperts.go | 23 +- cmd/fileexperts/fileexperts_test.go | 39 +- cmd/heartbeat/heartbeat.go | 86 +-- cmd/heartbeat/heartbeat_test.go | 237 ++++---- cmd/logfile/logfile.go | 5 +- cmd/logfile/logfile_test.go | 5 +- cmd/offline/offline.go | 48 +- cmd/offlinecount/offlinecount.go | 7 +- cmd/offlinecount/offlinecount_test.go | 5 +- cmd/offlineprint/offlineprint.go | 9 +- cmd/offlineprint/offlineprint_test.go | 5 +- cmd/offlinesync/offlinesync.go | 65 ++- cmd/offlinesync/offlinesync_internal_test.go | 3 +- cmd/offlinesync/offlinesync_test.go | 11 +- cmd/params/params.go | 82 +-- cmd/params/params_internal_test.go | 16 +- cmd/params/params_test.go | 514 ++++++++++-------- cmd/run.go | 182 ++++--- cmd/run_internal_test.go | 42 +- cmd/run_test.go | 102 ++-- cmd/today/today.go | 17 +- cmd/today/today_test.go | 11 +- cmd/todaygoal/todaygoal.go | 21 +- cmd/todaygoal/todaygoal_test.go | 13 +- cmd/version.go | 3 +- main_test.go | 63 ++- pkg/api/api.go | 18 +- pkg/api/diagnostic.go | 14 +- pkg/api/diagnostic_test.go | 3 +- pkg/api/fileexperts.go | 9 +- pkg/api/fileexperts_test.go | 13 +- pkg/api/goal.go | 5 +- pkg/api/goal_test.go | 19 +- pkg/api/heartbeat.go | 39 +- pkg/api/heartbeat_test.go | 23 +- pkg/api/option.go | 13 +- pkg/api/option_test.go | 44 +- pkg/api/statusbar.go | 5 +- pkg/api/statusbar_test.go | 18 +- pkg/api/transport.go | 15 +- pkg/api/transport_other.go | 3 +- pkg/api/transport_windows.go | 7 +- pkg/apikey/apikey.go | 21 +- pkg/apikey/apikey_test.go | 22 +- pkg/backoff/backoff.go | 32 +- pkg/backoff/backoff_internal_test.go | 27 +- pkg/backoff/backoff_test.go | 75 +-- pkg/deps/c.go | 7 +- pkg/deps/c_test.go | 3 +- pkg/deps/cpp.go | 7 +- pkg/deps/cpp_test.go | 3 +- pkg/deps/csharp.go | 7 +- pkg/deps/csharp_test.go | 3 +- pkg/deps/deps.go | 32 +- pkg/deps/deps_test.go | 27 +- pkg/deps/elm.go | 7 +- pkg/deps/elm_test.go | 3 +- pkg/deps/golang.go | 7 +- pkg/deps/golang_test.go | 3 +- pkg/deps/haskell.go | 7 +- pkg/deps/haskell_test.go | 3 +- pkg/deps/haxe.go | 7 +- pkg/deps/haxe_test.go | 3 +- pkg/deps/html.go | 7 +- pkg/deps/html_test.go | 5 +- pkg/deps/java.go | 7 +- pkg/deps/java_test.go | 3 +- pkg/deps/javascript.go | 7 +- pkg/deps/javascript_test.go | 5 +- pkg/deps/json.go | 7 +- pkg/deps/json_test.go | 5 +- pkg/deps/kotlin.go | 7 +- pkg/deps/kotlin_test.go | 3 +- pkg/deps/objectivec.go | 7 +- pkg/deps/objectivec_test.go | 3 +- pkg/deps/php.go | 7 +- pkg/deps/php_test.go | 3 +- pkg/deps/python.go | 7 +- pkg/deps/python_test.go | 3 +- pkg/deps/rust.go | 7 +- pkg/deps/rust_test.go | 3 +- pkg/deps/scala.go | 7 +- pkg/deps/scala_test.go | 3 +- pkg/deps/swift.go | 7 +- pkg/deps/swift_test.go | 3 +- pkg/deps/unknown.go | 3 +- pkg/deps/unknown_test.go | 5 +- pkg/deps/vbnet.go | 7 +- pkg/deps/vbnet_test.go | 3 +- pkg/fileexperts/fileexperts.go | 7 +- pkg/fileexperts/validation.go | 11 +- pkg/fileexperts/validation_test.go | 5 +- pkg/filestats/filestats.go | 22 +- pkg/filestats/filestats_test.go | 13 +- pkg/filter/filter.go | 36 +- pkg/filter/filter_test.go | 27 +- .../entity_modifier_internal_test.go | 9 +- pkg/heartbeat/entity_modify.go | 21 +- pkg/heartbeat/entity_modify_test.go | 5 +- pkg/heartbeat/format.go | 34 +- pkg/heartbeat/format_test.go | 7 +- pkg/heartbeat/heartbeat.go | 23 +- pkg/heartbeat/heartbeat_test.go | 19 +- pkg/heartbeat/sanitize.go | 26 +- pkg/heartbeat/sanitize_test.go | 73 +-- pkg/ini/ini.go | 45 +- pkg/ini/ini_test.go | 64 ++- pkg/language/chroma.go | 43 +- pkg/language/language.go | 36 +- pkg/language/language_test.go | 49 +- pkg/lexer/ruby.go | 3 - pkg/log/context.go | 39 ++ pkg/log/log.go | 161 ++++-- pkg/log/log_test.go | 45 ++ pkg/metrics/metrics.go | 15 +- pkg/offline/legacy.go | 5 +- pkg/offline/legacy_test.go | 5 +- pkg/offline/offline.go | 113 ++-- pkg/offline/offline_test.go | 61 ++- pkg/project/file.go | 16 +- pkg/project/file_test.go | 13 +- pkg/project/filter.go | 12 +- pkg/project/filter_test.go | 5 +- pkg/project/git.go | 83 +-- pkg/project/git_test.go | 39 +- pkg/project/map.go | 15 +- pkg/project/map_test.go | 20 +- pkg/project/mercurial.go | 14 +- pkg/project/mercurial_test.go | 7 +- pkg/project/project.go | 57 +- pkg/project/project_test.go | 81 +-- pkg/project/subversion.go | 17 +- pkg/project/subversion_test.go | 5 +- pkg/project/tfvc.go | 5 +- pkg/project/tfvc_test.go | 5 +- pkg/regex/regex.go | 54 +- pkg/regex/regex_internal_test.go | 9 +- pkg/remote/remote.go | 119 ++-- pkg/remote/remote_test.go | 173 +++--- pkg/system/system_linux.go | 7 +- pkg/system/system_other.go | 3 +- pkg/system/system_other_test.go | 22 + 146 files changed, 2466 insertions(+), 1693 deletions(-) create mode 100644 pkg/log/context.go create mode 100644 pkg/log/log_test.go create mode 100644 pkg/system/system_other_test.go diff --git a/cmd/api/api.go b/cmd/api/api.go index b8dca1b6..59cb88fe 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "strings" @@ -13,7 +14,7 @@ import ( // NewClient initializes a new api client with all options following the // passed in parameters. -func NewClient(params paramscmd.API) (*api.Client, error) { +func NewClient(ctx context.Context, params paramscmd.API) (*api.Client, error) { withAuth, err := api.WithAuth(api.BasicAuth{ Secret: params.Key, }) @@ -21,23 +22,25 @@ func NewClient(params paramscmd.API) (*api.Client, error) { return nil, fmt.Errorf("failed to set up auth option on api client: %w", err) } - return newClient(params, withAuth) + return newClient(ctx, params, withAuth) } // NewClientWithoutAuth initializes a new api client with all options following the // passed in parameters and disabled authentication. -func NewClientWithoutAuth(params paramscmd.API) (*api.Client, error) { - return newClient(params) +func NewClientWithoutAuth(ctx context.Context, params paramscmd.API) (*api.Client, error) { + return newClient(ctx, params) } // newClient contains the logic of client initialization, except auth initialization. -func newClient(params paramscmd.API, opts ...api.Option) (*api.Client, error) { +func newClient(ctx context.Context, params paramscmd.API, opts ...api.Option) (*api.Client, error) { opts = append(opts, api.WithTimeout(params.Timeout)) opts = append(opts, api.WithHostname(strings.TrimSpace(params.Hostname))) + logger := log.Extract(ctx) + tz, err := timezone() if err != nil { - log.Debugf("failed to detect local timezone: %s", err) + logger.Debugf("failed to detect local timezone: %s", err) } else { opts = append(opts, api.WithTimezone(strings.TrimSpace(tz))) } @@ -54,7 +57,7 @@ func newClient(params paramscmd.API, opts ...api.Option) (*api.Client, error) { opts = append(opts, withSSLCert) } else if !params.DisableSSLVerify { - opts = append(opts, api.WithSSLCertPool(api.CACerts())) + opts = append(opts, api.WithSSLCertPool(api.CACerts(ctx))) } if params.ProxyURL != "" { @@ -66,7 +69,7 @@ func newClient(params paramscmd.API, opts ...api.Option) (*api.Client, error) { opts = append(opts, withProxy) if strings.Contains(params.ProxyURL, `\\`) { - withNTLMRetry, err := api.WithNTLMRequestRetry(params.ProxyURL) + withNTLMRetry, err := api.WithNTLMRequestRetry(ctx, params.ProxyURL) if err != nil { return nil, fmt.Errorf("failed to set up ntlm request retry option on api client: %w", err) } @@ -75,7 +78,7 @@ func newClient(params paramscmd.API, opts ...api.Option) (*api.Client, error) { } } - opts = append(opts, api.WithUserAgent(params.Plugin)) + opts = append(opts, api.WithUserAgent(ctx, params.Plugin)) return api.NewClient(params.URL, opts...), nil } diff --git a/cmd/configread/configread.go b/cmd/configread/configread.go index 65d1c36b..2145ef4e 100644 --- a/cmd/configread/configread.go +++ b/cmd/configread/configread.go @@ -1,6 +1,7 @@ package configread import ( + "context" "errors" "fmt" "strings" @@ -18,7 +19,7 @@ type Params struct { } // Run prints the value for the given config key. -func Run(v *viper.Viper) (int, error) { +func Run(_ context.Context, v *viper.Viper) (int, error) { output, err := Read(v) if err != nil { return exitcode.ErrConfigFileRead, fmt.Errorf( diff --git a/cmd/configwrite/configwrite.go b/cmd/configwrite/configwrite.go index 43d2871e..ee39f6e6 100644 --- a/cmd/configwrite/configwrite.go +++ b/cmd/configwrite/configwrite.go @@ -1,6 +1,7 @@ package configwrite import ( + "context" "errors" "fmt" "strings" @@ -19,8 +20,8 @@ type Params struct { } // Run loads wakatime config file and call Write(). -func Run(v *viper.Viper) (int, error) { - w, err := ini.NewWriter(v, ini.FilePath) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + w, err := ini.NewWriter(ctx, v, ini.FilePath) if err != nil { return exitcode.ErrConfigFileParse, fmt.Errorf( "failed to parse config file: %s", @@ -28,7 +29,7 @@ func Run(v *viper.Viper) (int, error) { ) } - if err := Write(v, w); err != nil { + if err := Write(ctx, v, w); err != nil { return exitcode.ErrGeneric, fmt.Errorf( "failed to write to config file: %s", err, @@ -39,13 +40,13 @@ func Run(v *viper.Viper) (int, error) { } // Write writes value(s) to given config key(s) and persist on disk. -func Write(v *viper.Viper, w ini.Writer) error { +func Write(ctx context.Context, v *viper.Viper, w ini.Writer) error { params, err := LoadParams(v) if err != nil { return fmt.Errorf("failed to load command parameters: %w", err) } - return w.Write(params.Section, params.KeyValue) + return w.Write(ctx, params.Section, params.KeyValue) } // LoadParams loads needed data from the configuration file. diff --git a/cmd/configwrite/configwrite_test.go b/cmd/configwrite/configwrite_test.go index d96e0b7a..c602dc41 100644 --- a/cmd/configwrite/configwrite_test.go +++ b/cmd/configwrite/configwrite_test.go @@ -1,6 +1,7 @@ package configwrite_test import ( + "context" "errors" "fmt" "os" @@ -72,8 +73,10 @@ func TestWrite(t *testing.T) { defer tmpFile.Close() + ctx := context.Background() + v := viper.New() - ini, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + ini, err := ini.NewWriter(ctx, v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return tmpFile.Name(), nil }) @@ -82,7 +85,7 @@ func TestWrite(t *testing.T) { v.Set("config-section", "settings") v.Set("config-write", map[string]string{"debug": "false"}) - err = configwrite.Write(v, ini) + err = configwrite.Write(ctx, v, ini) require.NoError(t, err) err = ini.File.Reload() @@ -117,7 +120,7 @@ func TestWriteErr(t *testing.T) { v.Set("config-section", test.Section) v.Set("config-write", test.Value) - err := configwrite.Write(v, w) + err := configwrite.Write(context.Background(), v, w) require.Error(t, err) assert.Equal( @@ -133,7 +136,7 @@ func TestWriteErr(t *testing.T) { func TestWriteSaveErr(t *testing.T) { v := viper.New() w := &mockWriter{ - WriteFn: func(section string, keyValue map[string]string) error { + WriteFn: func(_ context.Context, section string, keyValue map[string]string) error { assert.Equal(t, "settings", section) assert.Equal(t, map[string]string{"debug": "false"}, keyValue) @@ -144,14 +147,14 @@ func TestWriteSaveErr(t *testing.T) { v.Set("config-section", "settings") v.Set("config-write", map[string]string{"debug": "false"}) - err := configwrite.Write(v, w) + err := configwrite.Write(context.Background(), v, w) assert.Error(t, err) } type mockWriter struct { - WriteFn func(section string, keyValue map[string]string) error + WriteFn func(ctx context.Context, section string, keyValue map[string]string) error } -func (m *mockWriter) Write(section string, keyValue map[string]string) error { - return m.WriteFn(section, keyValue) +func (m *mockWriter) Write(ctx context.Context, section string, keyValue map[string]string) error { + return m.WriteFn(ctx, section, keyValue) } diff --git a/cmd/fileexperts/fileexperts.go b/cmd/fileexperts/fileexperts.go index fea2295f..5c58b88b 100644 --- a/cmd/fileexperts/fileexperts.go +++ b/cmd/fileexperts/fileexperts.go @@ -1,6 +1,7 @@ package fileexperts import ( + "context" "fmt" apicmd "github.com/wakatime/wakatime-cli/cmd/api" @@ -18,8 +19,8 @@ import ( ) // Run executes the file-experts command. -func Run(v *viper.Viper) (int, error) { - output, err := FileExperts(v) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + output, err := FileExperts(ctx, v) if err != nil { if errwaka, ok := err.(wakaerror.Error); ok { return errwaka.ExitCode(), fmt.Errorf("file experts fetch failed: %s", errwaka.Message()) @@ -31,29 +32,31 @@ func Run(v *viper.Viper) (int, error) { ) } - log.Debugln("successfully fetched file experts") + logger := log.Extract(ctx) + logger.Debugln("successfully fetched file experts") + fmt.Println(output) return exitcode.Success, nil } // FileExperts returns a rendered file experts of todays coding activity. -func FileExperts(v *viper.Viper) (string, error) { - params, err := LoadParams(v) +func FileExperts(ctx context.Context, v *viper.Viper) (string, error) { + params, err := LoadParams(ctx, v) if err != nil { return "", fmt.Errorf("failed to load command parameters: %w", err) } handleOpts := initHandleOptions(params) - apiClient, err := apicmd.NewClientWithoutAuth(params.API) + apiClient, err := apicmd.NewClientWithoutAuth(ctx, params.API) if err != nil { return "", fmt.Errorf("failed to initialize api client: %w", err) } handle := fileexperts.NewHandle(apiClient, handleOpts...) - results, err := handle([]heartbeat.Heartbeat{{Entity: params.Heartbeat.Entity}}) + results, err := handle(ctx, []heartbeat.Heartbeat{{Entity: params.Heartbeat.Entity}}) if err != nil { return "", err } @@ -75,17 +78,17 @@ func FileExperts(v *viper.Viper) (string, error) { // LoadParams loads file-expert config params from viper.Viper instance. Returns ErrAuth // if failed to retrieve api key. -func LoadParams(v *viper.Viper) (paramscmd.Params, error) { +func LoadParams(ctx context.Context, v *viper.Viper) (paramscmd.Params, error) { if v == nil { return paramscmd.Params{}, fmt.Errorf("viper instance unset") } - heartbeatParams, err := paramscmd.LoadHeartbeatParams(v) + heartbeatParams, err := paramscmd.LoadHeartbeatParams(ctx, v) if err != nil { return paramscmd.Params{}, fmt.Errorf("failed to load heartbeat params: %s", err) } - apiParams, err := paramscmd.LoadAPIParams(v) + apiParams, err := paramscmd.LoadAPIParams(ctx, v) if err != nil { return paramscmd.Params{}, fmt.Errorf("failed to load API parameters: %w", err) } diff --git a/cmd/fileexperts/fileexperts_test.go b/cmd/fileexperts/fileexperts_test.go index a18e05d9..423e4670 100644 --- a/cmd/fileexperts/fileexperts_test.go +++ b/cmd/fileexperts/fileexperts_test.go @@ -1,6 +1,7 @@ package fileexperts_test import ( + "context" "encoding/json" "fmt" "io" @@ -86,9 +87,8 @@ func TestFileExperts(t *testing.T) { v.Set("plugin", plugin) v.Set("project", "wakatime-cli") v.Set("entity", "testdata/main.go") - v.Set("file-experts", true) - output, err := fileexperts.FileExperts(v) + output, err := fileexperts.FileExperts(context.Background(), v) require.NoError(t, err) assert.Equal(t, "You: 40 mins | Karl: 21 mins", output) @@ -97,33 +97,31 @@ func TestFileExperts(t *testing.T) { } func TestFileExperts_NonExistingEntity(t *testing.T) { - tmpDir := t.TempDir() + ctx := context.Background() - logFile, err := os.CreateTemp(tmpDir, "") + logFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) defer logFile.Close() v := viper.New() + v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", "https://example.org") v.Set("entity", "nonexisting") - v.Set("file-experts", true) - v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("log-file", logFile.Name()) v.Set("verbose", true) - cmd.SetupLogging(v) + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) - defer func() { - if file, ok := log.Output().(*os.File); ok { - _ = file.Sync() - file.Close() - } else if handler, ok := log.Output().(io.Closer); ok { - handler.Close() - } - }() + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) + + _, err = fileexperts.FileExperts(ctx, v) + require.NoError(t, err) - _, err = fileexperts.FileExperts(v) + err = logFile.Sync() require.NoError(t, err) output, err := io.ReadAll(logFile) @@ -148,9 +146,8 @@ func TestFileExperts_ErrApi(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", testServerURL) v.Set("entity", "testdata/main.go") - v.Set("file-experts", true) - _, err := fileexperts.FileExperts(v) + _, err := fileexperts.FileExperts(context.Background(), v) require.Error(t, err) var errapi api.Err @@ -183,9 +180,8 @@ func TestFileExperts_ErrAuth(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", testServerURL) v.Set("entity", "testdata/main.go") - v.Set("file-experts", true) - _, err := fileexperts.FileExperts(v) + _, err := fileexperts.FileExperts(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth @@ -217,9 +213,8 @@ func TestFileExperts_ErrBadRequest(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", testServerURL) v.Set("entity", "testdata/main.go") - v.Set("file-experts", true) - _, err := fileexperts.FileExperts(v) + _, err := fileexperts.FileExperts(context.Background(), v) require.Error(t, err) var errbadRequest api.ErrBadRequest diff --git a/cmd/heartbeat/heartbeat.go b/cmd/heartbeat/heartbeat.go index fce3b447..9f1267cd 100644 --- a/cmd/heartbeat/heartbeat.go +++ b/cmd/heartbeat/heartbeat.go @@ -1,6 +1,7 @@ package heartbeat import ( + "context" "errors" "fmt" "strings" @@ -30,13 +31,15 @@ import ( ) // Run executes the heartbeat command. -func Run(v *viper.Viper) (int, error) { - queueFilepath, err := offline.QueueFilepath(v) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + logger := log.Extract(ctx) + + queueFilepath, err := offline.QueueFilepath(ctx, v) if err != nil { - log.Warnf("failed to load offline queue filepath: %s", err) + logger.Warnf("failed to load offline queue filepath: %s", err) } - err = SendHeartbeats(v, queueFilepath) + err = SendHeartbeats(ctx, v, queueFilepath) if err != nil { var errauth api.ErrAuth @@ -44,8 +47,8 @@ func Run(v *viper.Viper) (int, error) { // Save heartbeats to offline db even when api key invalid. // It avoids losing heartbeats when api key is invalid. if errors.As(err, &errauth) { - if err := offlinecmd.SaveHeartbeats(v, nil, queueFilepath); err != nil { - log.Errorf("failed to save heartbeats to offline queue: %s", err) + if err := offlinecmd.SaveHeartbeats(ctx, v, nil, queueFilepath); err != nil { + logger.Errorf("failed to save heartbeats to offline queue: %s", err) } return errauth.ExitCode(), fmt.Errorf("sending heartbeat(s) failed: %w", errauth) @@ -61,7 +64,7 @@ func Run(v *viper.Viper) (int, error) { ) } - log.Debugln("successfully sent heartbeat(s)") + logger.Debugln("successfully sent heartbeat(s)") return exitcode.Success, nil } @@ -69,29 +72,31 @@ func Run(v *viper.Viper) (int, error) { // SendHeartbeats sends a heartbeat to the wakatime api and includes additional // heartbeats from the offline queue, if available and offline sync is not // explicitly disabled. -func SendHeartbeats(v *viper.Viper, queueFilepath string) error { - params, err := LoadParams(v) +func SendHeartbeats(ctx context.Context, v *viper.Viper, queueFilepath string) error { + params, err := LoadParams(ctx, v) if err != nil { return fmt.Errorf("failed to load command parameters: %w", err) } - setLogFields(params) - log.Debugf("params: %s", params) + logger := log.Extract(ctx) + + setLogFields(ctx, params) + logger.Debugf("params: %s", params) if RateLimited(RateLimitParams{ Disabled: params.Offline.Disabled, LastSentAt: params.Offline.LastSentAt, Timeout: params.Offline.RateLimit, }) { - if err = offlinecmd.SaveHeartbeats(v, nil, queueFilepath); err == nil { + if err = offlinecmd.SaveHeartbeats(ctx, v, nil, queueFilepath); err == nil { return nil } // log offline db error then try to send heartbeats to API so they're not lost - log.Errorf("failed to save rate limited heartbeats: %s", err) + logger.Errorf("failed to save rate limited heartbeats: %s", err) } - heartbeats := buildHeartbeats(params) + heartbeats := buildHeartbeats(ctx, params) var chOfflineSave = make(chan bool) @@ -99,11 +104,11 @@ func SendHeartbeats(v *viper.Viper, queueFilepath string) error { if len(heartbeats) > offline.SendLimit { extraHeartbeats := heartbeats[offline.SendLimit:] - log.Debugf("save %d extra heartbeat(s) to offline queue", len(extraHeartbeats)) + logger.Debugf("save %d extra heartbeat(s) to offline queue", len(extraHeartbeats)) go func(done chan<- bool) { - if err := offlinecmd.SaveHeartbeats(v, extraHeartbeats, queueFilepath); err != nil { - log.Errorf("failed to save extra heartbeats to offline queue: %s", err) + if err := offlinecmd.SaveHeartbeats(ctx, v, extraHeartbeats, queueFilepath); err != nil { + logger.Errorf("failed to save extra heartbeats to offline queue: %s", err) } done <- true @@ -125,11 +130,11 @@ func SendHeartbeats(v *viper.Viper, queueFilepath string) error { HasProxy: params.API.ProxyURL != "", })) - apiClient, err := apicmd.NewClientWithoutAuth(params.API) + apiClient, err := apicmd.NewClientWithoutAuth(ctx, params.API) if err != nil { if !params.Offline.Disabled { - if err := offlinecmd.SaveHeartbeats(v, heartbeats, queueFilepath); err != nil { - log.Errorf("failed to save heartbeats to offline queue: %s", err) + if err := offlinecmd.SaveHeartbeats(ctx, v, heartbeats, queueFilepath); err != nil { + logger.Errorf("failed to save heartbeats to offline queue: %s", err) } } @@ -137,7 +142,7 @@ func SendHeartbeats(v *viper.Viper, queueFilepath string) error { } handle := heartbeat.NewHandle(apiClient, handleOpts...) - results, err := handle(heartbeats) + results, err := handle(ctx, heartbeats) // wait for offline queue save to finish if len(heartbeats) > offline.SendLimit { @@ -150,12 +155,12 @@ func SendHeartbeats(v *viper.Viper, queueFilepath string) error { for _, result := range results { if len(result.Errors) > 0 { - log.Warnln(strings.Join(result.Errors, " ")) + logger.Warnln(strings.Join(result.Errors, " ")) } } - if err := ResetRateLimit(v); err != nil { - log.Errorf("failed to reset rate limit: %s", err) + if err := ResetRateLimit(ctx, v); err != nil { + logger.Errorf("failed to reset rate limit: %s", err) } return nil @@ -163,17 +168,17 @@ func SendHeartbeats(v *viper.Viper, queueFilepath string) error { // LoadParams loads params from viper.Viper instance. Returns ErrAuth // if failed to retrieve api key. -func LoadParams(v *viper.Viper) (paramscmd.Params, error) { +func LoadParams(ctx context.Context, v *viper.Viper) (paramscmd.Params, error) { if v == nil { return paramscmd.Params{}, errors.New("viper instance unset") } - apiParams, err := paramscmd.LoadAPIParams(v) + apiParams, err := paramscmd.LoadAPIParams(ctx, v) if err != nil { return paramscmd.Params{}, fmt.Errorf("failed to load API parameters: %w", err) } - heartbeatParams, err := paramscmd.LoadHeartbeatParams(v) + heartbeatParams, err := paramscmd.LoadHeartbeatParams(ctx, v) if err != nil { return paramscmd.Params{}, fmt.Errorf("failed to load heartbeat params: %s", err) } @@ -181,7 +186,7 @@ func LoadParams(v *viper.Viper) (paramscmd.Params, error) { return paramscmd.Params{ API: apiParams, Heartbeat: heartbeatParams, - Offline: paramscmd.LoadOfflineParams(v), + Offline: paramscmd.LoadOfflineParams(ctx, v), }, nil } @@ -210,8 +215,8 @@ func RateLimited(params RateLimitParams) bool { } // ResetRateLimit updates the internal.heartbeats_last_sent_at timestamp. -func ResetRateLimit(v *viper.Viper) error { - w, err := ini.NewWriter(v, ini.InternalFilePath) +func ResetRateLimit(ctx context.Context, v *viper.Viper) error { + w, err := ini.NewWriter(ctx, v, ini.InternalFilePath) if err != nil { return fmt.Errorf("failed to parse config file: %s", err) } @@ -220,17 +225,17 @@ func ResetRateLimit(v *viper.Viper) error { "heartbeats_last_sent_at": time.Now().Format(ini.DateFormat), } - if err := w.Write("internal", keyValue); err != nil { + if err := w.Write(ctx, "internal", keyValue); err != nil { return fmt.Errorf("failed to write to internal config file: %s", err) } return nil } -func buildHeartbeats(params paramscmd.Params) []heartbeat.Heartbeat { +func buildHeartbeats(ctx context.Context, params paramscmd.Params) []heartbeat.Heartbeat { heartbeats := []heartbeat.Heartbeat{} - userAgent := heartbeat.UserAgent(params.API.Plugin) + userAgent := heartbeat.UserAgent(ctx, params.API.Plugin) heartbeats = append(heartbeats, heartbeat.New( params.Heartbeat.Project.BranchAlternate, @@ -256,7 +261,8 @@ func buildHeartbeats(params paramscmd.Params) []heartbeat.Heartbeat { )) if len(params.Heartbeat.ExtraHeartbeats) > 0 { - log.Debugf("include %d extra heartbeat(s) from stdin", len(params.Heartbeat.ExtraHeartbeats)) + logger := log.Extract(ctx) + logger.Debugf("include %d extra heartbeat(s) from stdin", len(params.Heartbeat.ExtraHeartbeats)) for _, h := range params.Heartbeat.ExtraHeartbeats { heartbeats = append(heartbeats, heartbeat.New( @@ -331,19 +337,19 @@ func initHandleOptions(params paramscmd.Params) []heartbeat.HandleOption { } } -func setLogFields(params paramscmd.Params) { - log.WithField("file", params.Heartbeat.Entity) - log.WithField("time", params.Heartbeat.Time) +func setLogFields(ctx context.Context, params paramscmd.Params) { + log.AddField(ctx, "file", params.Heartbeat.Entity) + log.AddField(ctx, "time", params.Heartbeat.Time) if params.API.Plugin != "" { - log.WithField("plugin", params.API.Plugin) + log.AddField(ctx, "plugin", params.API.Plugin) } if params.Heartbeat.LineNumber != nil { - log.WithField("lineno", params.Heartbeat.LineNumber) + log.AddField(ctx, "lineno", params.Heartbeat.LineNumber) } if params.Heartbeat.IsWrite != nil { - log.WithField("is_write", params.Heartbeat.IsWrite) + log.AddField(ctx, "is_write", params.Heartbeat.IsWrite) } } diff --git a/cmd/heartbeat/heartbeat_test.go b/cmd/heartbeat/heartbeat_test.go index a84dc982..e0832009 100644 --- a/cmd/heartbeat/heartbeat_test.go +++ b/cmd/heartbeat/heartbeat_test.go @@ -1,6 +1,7 @@ package heartbeat_test import ( + "context" "encoding/json" "fmt" "io" @@ -80,7 +81,12 @@ func TestSendHeartbeats(t *testing.T) { err = json.Unmarshal(body, &[]any{&entity}) require.NoError(t, err) - expectedBodyStr := fmt.Sprintf(string(expectedBody), entity.Entity, subfolders, heartbeat.UserAgent(plugin)) + expectedBodyStr := fmt.Sprintf( + string(expectedBody), + entity.Entity, + subfolders, + heartbeat.UserAgent(context.Background(), plugin), + ) assert.True(t, strings.HasSuffix(entity.Entity, "testdata/main.go")) assert.JSONEq(t, expectedBodyStr, string(body)) @@ -119,17 +125,13 @@ func TestSendHeartbeats(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(context.Background(), v, offlineQueueFile.Name()) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) } func TestSendHeartbeats_RateLimited(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is windows.") - } - resetSingleton(t) testServerURL, router, tearDown := setupTestServer() @@ -178,7 +180,7 @@ func TestSendHeartbeats_RateLimited(t *testing.T) { v.Set("offline-queue-file", offlineQueueFile.Name()) v.Set("internal.heartbeats_last_sent_at", time.Now().Add(-time.Minute).Format(time.RFC3339)) - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(context.Background(), v, offlineQueueFile.Name()) require.NoError(t, err) assert.Zero(t, numCalls) @@ -214,7 +216,7 @@ func TestSendHeartbeats_WithFiltering_Exclude(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(context.Background(), v, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 0, numCalls) @@ -231,6 +233,8 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { numCalls int ) + ctx := context.Background() + projectFolder, err := filepath.Abs("../..") require.NoError(t, err) @@ -259,33 +263,35 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { assert.True(t, strings.HasSuffix(entities[i].Entity, "testdata/main.go")) } + userAgent := heartbeat.UserAgent(ctx, plugin) + expectedBodyStr := fmt.Sprintf( string(expectedBody), - entities[0].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[1].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[2].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[3].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[4].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[5].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[6].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[7].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[8].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[9].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[10].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[11].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[12].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[13].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[14].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[15].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[16].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[17].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[18].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[19].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[20].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[21].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[22].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[23].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[24].Entity, subfolders, heartbeat.UserAgent(plugin), + entities[0].Entity, subfolders, userAgent, + entities[1].Entity, subfolders, userAgent, + entities[2].Entity, subfolders, userAgent, + entities[3].Entity, subfolders, userAgent, + entities[4].Entity, subfolders, userAgent, + entities[5].Entity, subfolders, userAgent, + entities[6].Entity, subfolders, userAgent, + entities[7].Entity, subfolders, userAgent, + entities[8].Entity, subfolders, userAgent, + entities[9].Entity, subfolders, userAgent, + entities[10].Entity, subfolders, userAgent, + entities[11].Entity, subfolders, userAgent, + entities[12].Entity, subfolders, userAgent, + entities[13].Entity, subfolders, userAgent, + entities[14].Entity, subfolders, userAgent, + entities[15].Entity, subfolders, userAgent, + entities[16].Entity, subfolders, userAgent, + entities[17].Entity, subfolders, userAgent, + entities[18].Entity, subfolders, userAgent, + entities[19].Entity, subfolders, userAgent, + entities[20].Entity, subfolders, userAgent, + entities[21].Entity, subfolders, userAgent, + entities[22].Entity, subfolders, userAgent, + entities[23].Entity, subfolders, userAgent, + entities[24].Entity, subfolders, userAgent, ) assert.JSONEq(t, expectedBodyStr, string(body)) @@ -351,10 +357,10 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) - offlineCount, err := offline.CountHeartbeats(offlineQueueFile.Name()) + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, offlineCount) @@ -373,6 +379,8 @@ func TestSendHeartbeats_ExtraHeartbeats_Sanitize(t *testing.T) { numCalls int ) + ctx := context.Background() + router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, _ *http.Request) { // send response w.WriteHeader(http.StatusCreated) @@ -436,10 +444,10 @@ func TestSendHeartbeats_ExtraHeartbeats_Sanitize(t *testing.T) { defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) - offlineCount, err := offline.CountHeartbeats(offlineQueueFile.Name()) + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) db, err := bolt.Open(offlineQueueFile.Name(), 0600, nil) @@ -508,6 +516,8 @@ func TestSendHeartbeats_NonExistingEntity(t *testing.T) { defer logFile.Close() + ctx := context.Background() + v := viper.New() v.SetDefault("sync-offline-activity", 1000) v.Set("api-url", "https://example.org") @@ -517,23 +527,19 @@ func TestSendHeartbeats_NonExistingEntity(t *testing.T) { v.Set("log-file", logFile.Name()) v.Set("verbose", true) - cmd.SetupLogging(v) + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) - defer func() { - if file, ok := log.Output().(*os.File); ok { - _ = file.Sync() - file.Close() - } else if handler, ok := log.Output().(io.Closer); ok { - handler.Close() - } - }() + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) f, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) defer f.Close() - err = cmdheartbeat.SendHeartbeats(v, f.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, f.Name()) require.NoError(t, err) output, err := io.ReadAll(logFile) @@ -553,6 +559,8 @@ func TestSendHeartbeats_IsUnsavedEntity(t *testing.T) { numCalls int ) + ctx := context.Background() + projectFolder, err := filepath.Abs("../..") require.NoError(t, err) @@ -577,11 +585,13 @@ func TestSendHeartbeats_IsUnsavedEntity(t *testing.T) { assert.True(t, strings.HasSuffix(entities[1].Entity, "missing-from-extra-heartbeats")) assert.True(t, strings.HasSuffix(entities[2].Entity, "main.go")) + userAgent := heartbeat.UserAgent(ctx, plugin) + expectedBodyStr := fmt.Sprintf( string(expectedBody), - entities[0].Entity, heartbeat.UserAgent(plugin), - entities[1].Entity, heartbeat.UserAgent(plugin), - entities[2].Entity, subfolders, heartbeat.UserAgent(plugin), + entities[0].Entity, userAgent, + entities[1].Entity, userAgent, + entities[2].Entity, subfolders, userAgent, ) assert.JSONEq(t, expectedBodyStr, string(body)) @@ -628,6 +638,8 @@ func TestSendHeartbeats_IsUnsavedEntity(t *testing.T) { logFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) + defer logFile.Close() + v := viper.New() v.SetDefault("sync-offline-activity", 1000) v.Set("api-url", testServerURL) @@ -650,24 +662,19 @@ func TestSendHeartbeats_IsUnsavedEntity(t *testing.T) { v.Set("log-file", logFile.Name()) v.Set("verbose", true) - cmd.SetupLogging(v) + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) offlineQueueFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer func() { - offlineQueueFile.Close() - logFile.Close() - - if file, ok := log.Output().(*os.File); ok { - _ = file.Sync() - file.Close() - } else if handler, ok := log.Output().(io.Closer); ok { - handler.Close() - } - }() + defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) output, err := io.ReadAll(logFile) @@ -687,6 +694,8 @@ func TestSendHeartbeats_NonExistingExtraHeartbeatsEntity(t *testing.T) { numCalls int ) + ctx := context.Background() + projectFolder, err := filepath.Abs("../..") require.NoError(t, err) @@ -710,10 +719,12 @@ func TestSendHeartbeats_NonExistingExtraHeartbeatsEntity(t *testing.T) { assert.True(t, strings.HasSuffix(entities[0].Entity, "testdata/main.go")) assert.True(t, strings.HasSuffix(entities[1].Entity, "testdata/main.py")) + userAgent := heartbeat.UserAgent(ctx, plugin) + expectedBodyStr := fmt.Sprintf( string(expectedBody), - entities[0].Entity, subfolders, heartbeat.UserAgent(plugin), - entities[1].Entity, subfolders, heartbeat.UserAgent(plugin), + entities[0].Entity, subfolders, userAgent, + entities[1].Entity, subfolders, userAgent, ) assert.JSONEq(t, expectedBodyStr, string(body)) @@ -760,6 +771,8 @@ func TestSendHeartbeats_NonExistingExtraHeartbeatsEntity(t *testing.T) { logFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) + defer logFile.Close() + v := viper.New() v.SetDefault("sync-offline-activity", 1000) v.Set("api-url", testServerURL) @@ -774,24 +787,19 @@ func TestSendHeartbeats_NonExistingExtraHeartbeatsEntity(t *testing.T) { v.Set("log-file", logFile.Name()) v.Set("verbose", true) - cmd.SetupLogging(v) + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) offlineQueueFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer func() { - offlineQueueFile.Close() - logFile.Close() - - if file, ok := log.Output().(*os.File); ok { - _ = file.Sync() - file.Close() - } else if handler, ok := log.Output().(io.Closer); ok { - handler.Close() - } - }() + defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) output, err := io.ReadAll(logFile) @@ -821,7 +829,7 @@ func TestSendHeartbeats_ErrAuth_UnsetAPIKey(t *testing.T) { defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(context.Background(), v, offlineQueueFile.Name()) require.Error(t, err) var errauth api.ErrAuth @@ -851,6 +859,8 @@ func TestSendHeartbeats_ErrBackoff(t *testing.T) { w.WriteHeader(http.StatusInternalServerError) }) + ctx := context.Background() + tmpDir := t.TempDir() logFile, err := os.CreateTemp(tmpDir, "") @@ -859,7 +869,6 @@ func TestSendHeartbeats_ErrBackoff(t *testing.T) { defer logFile.Close() v := viper.New() - v.Set("internal.backoff_at", time.Now().Add(10*time.Minute).Format(ini.DateFormat)) v.Set("internal.backoff_retries", "1") v.SetDefault("sync-offline-activity", 1000) @@ -869,29 +878,24 @@ func TestSendHeartbeats_ErrBackoff(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("log-file", logFile.Name()) - cmd.SetupLogging(v) + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) - defer func() { - if file, ok := log.Output().(*os.File); ok { - _ = file.Sync() - file.Close() - } else if handler, ok := log.Output().(io.Closer); ok { - handler.Close() - } - }() + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) - require.Error(t, err) - assert.ErrorAs(t, err, &api.ErrBackoff{}) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) + require.ErrorAs(t, err, &api.ErrBackoff{}) assert.Equal(t, 0, numCalls) - offlineCount, err := offline.CountHeartbeats(offlineQueueFile.Name()) + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, offlineCount) @@ -916,6 +920,8 @@ func TestSendHeartbeats_ErrBackoff_Verbose(t *testing.T) { w.WriteHeader(http.StatusInternalServerError) }) + ctx := context.Background() + tmpDir := t.TempDir() logFile, err := os.CreateTemp(tmpDir, "") @@ -924,7 +930,6 @@ func TestSendHeartbeats_ErrBackoff_Verbose(t *testing.T) { defer logFile.Close() v := viper.New() - v.Set("internal.backoff_at", time.Now().Add(10*time.Minute).Format(ini.DateFormat)) v.Set("internal.backoff_retries", "1") v.SetDefault("sync-offline-activity", 1000) @@ -935,29 +940,25 @@ func TestSendHeartbeats_ErrBackoff_Verbose(t *testing.T) { v.Set("log-file", logFile.Name()) v.Set("verbose", true) - cmd.SetupLogging(v) + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() - defer func() { - if file, ok := log.Output().(*os.File); ok { - _ = file.Sync() - file.Close() - } else if handler, ok := log.Output().(io.Closer); ok { - handler.Close() - } - }() + ctx = log.ToContext(ctx, logger) offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.Error(t, err) assert.ErrorAs(t, err, &api.ErrBackoff{}) assert.Equal(t, 0, numCalls) - offlineCount, err := offline.CountHeartbeats(offlineQueueFile.Name()) + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, offlineCount) @@ -979,6 +980,8 @@ func TestSendHeartbeats_ObfuscateProject(t *testing.T) { numCalls int ) + ctx := context.Background() + fp := setupTestGitBasic(t) router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, req *http.Request) { @@ -1006,10 +1009,10 @@ func TestSendHeartbeats_ObfuscateProject(t *testing.T) { err = json.Unmarshal(body, &[]any{&entity}) require.NoError(t, err) - lines, err := project.ReadFile(filepath.Join(fp, "wakatime-cli", ".wakatime-project"), 1) + lines, err := project.ReadFile(ctx, filepath.Join(fp, "wakatime-cli", ".wakatime-project"), 1) require.NoError(t, err) - expectedBodyStr := fmt.Sprintf(string(expectedBody), entity.Entity, lines[0], heartbeat.UserAgent(plugin)) + expectedBodyStr := fmt.Sprintf(string(expectedBody), entity.Entity, lines[0], heartbeat.UserAgent(ctx, plugin)) assert.True(t, strings.HasSuffix(entity.Entity, "src/pkg/file.go")) assert.JSONEq(t, expectedBodyStr, string(body)) @@ -1048,7 +1051,7 @@ func TestSendHeartbeats_ObfuscateProject(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) @@ -1065,6 +1068,8 @@ func TestSendHeartbeats_ObfuscateProjectNotBranch(t *testing.T) { numCalls int ) + ctx := context.Background() + fp := setupTestGitBasic(t) router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, req *http.Request) { @@ -1092,10 +1097,10 @@ func TestSendHeartbeats_ObfuscateProjectNotBranch(t *testing.T) { err = json.Unmarshal(body, &[]any{&entity}) require.NoError(t, err) - lines, err := project.ReadFile(filepath.Join(fp, "wakatime-cli", ".wakatime-project"), 1) + lines, err := project.ReadFile(ctx, filepath.Join(fp, "wakatime-cli", ".wakatime-project"), 1) require.NoError(t, err) - expectedBodyStr := fmt.Sprintf(string(expectedBody), entity.Entity, lines[0], heartbeat.UserAgent(plugin)) + expectedBodyStr := fmt.Sprintf(string(expectedBody), entity.Entity, lines[0], heartbeat.UserAgent(ctx, plugin)) assert.True(t, strings.HasSuffix(entity.Entity, "src/pkg/file.go")) assert.JSONEq(t, expectedBodyStr, string(body)) @@ -1135,7 +1140,7 @@ func TestSendHeartbeats_ObfuscateProjectNotBranch(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) - err = cmdheartbeat.SendHeartbeats(v, offlineQueueFile.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) @@ -1177,7 +1182,7 @@ func TestRateLimited_TimeoutZero(t *testing.T) { resetSingleton(t) p := cmdheartbeat.RateLimitParams{ - LastSentAt: time.Time{}, + Timeout: 0, } assert.False(t, cmdheartbeat.RateLimited(p)) @@ -1187,7 +1192,7 @@ func TestRateLimited_LastSentAtZero(t *testing.T) { resetSingleton(t) p := cmdheartbeat.RateLimitParams{ - Timeout: 0, + LastSentAt: time.Time{}, } assert.False(t, cmdheartbeat.RateLimited(p)) @@ -1206,17 +1211,19 @@ func TestResetRateLimit(t *testing.T) { defer tmpFileInternal.Close() + ctx := context.Background() + v := viper.New() v.Set("config", tmpFileInternal.Name()) v.Set("internal-config", tmpFileInternal.Name()) - writer, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + writer, err := ini.NewWriter(ctx, v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return tmpFileInternal.Name(), nil }) require.NoError(t, err) - err = cmdheartbeat.ResetRateLimit(v) + err = cmdheartbeat.ResetRateLimit(ctx, v) require.NoError(t, err) err = writer.File.Reload() diff --git a/cmd/logfile/logfile.go b/cmd/logfile/logfile.go index 520e2584..d9afd67a 100644 --- a/cmd/logfile/logfile.go +++ b/cmd/logfile/logfile.go @@ -1,6 +1,7 @@ package logfile import ( + "context" "fmt" "path/filepath" @@ -23,7 +24,7 @@ type Params struct { } // LoadParams loads needed data from the configuration file. -func LoadParams(v *viper.Viper) (Params, error) { +func LoadParams(ctx context.Context, v *viper.Viper) (Params, error) { params := Params{ Metrics: vipertools.FirstNonEmptyBool( v, @@ -55,7 +56,7 @@ func LoadParams(v *viper.Viper) (Params, error) { return params, nil } - folder, err := ini.WakaResourcesDir() + folder, err := ini.WakaResourcesDir(ctx) if err != nil { return Params{}, fmt.Errorf("failed getting resource directory: %s", err) } diff --git a/cmd/logfile/logfile_test.go b/cmd/logfile/logfile_test.go index 8824e209..95ed2f61 100644 --- a/cmd/logfile/logfile_test.go +++ b/cmd/logfile/logfile_test.go @@ -1,6 +1,7 @@ package logfile_test import ( + "context" "os" "path/filepath" "testing" @@ -28,6 +29,8 @@ func TestLoadParams(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) + ctx := context.Background() + tests := map[string]struct { EnvVar string ViperDebug bool @@ -134,7 +137,7 @@ func TestLoadParams(t *testing.T) { defer os.Unsetenv("WAKATIME_HOME") - params, err := logfile.LoadParams(v) + params, err := logfile.LoadParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params) diff --git a/cmd/offline/offline.go b/cmd/offline/offline.go index 3f146c20..f7c68ed1 100644 --- a/cmd/offline/offline.go +++ b/cmd/offline/offline.go @@ -1,6 +1,7 @@ package offline import ( + "context" "errors" "fmt" @@ -21,15 +22,16 @@ import ( // SaveHeartbeats saves heartbeats to the offline db without trying to send to the API. // Used when we have more heartbeats than `offline.SendLimit`, when we couldn't send // heartbeats to the API, or the API returned an auth error. -func SaveHeartbeats(v *viper.Viper, heartbeats []heartbeat.Heartbeat, queueFilepath string) error { - params, err := loadParams(v) +func SaveHeartbeats(ctx context.Context, v *viper.Viper, heartbeats []heartbeat.Heartbeat, queueFilepath string) error { + params, err := loadParams(ctx, v) if err != nil { return fmt.Errorf("failed to load command parameters: %w", err) } - setLogFields(params) + logger := log.Extract(ctx) - log.Debugf("params: %s", params) + setLogFields(ctx, params) + logger.Debugf("params: %s", params) if params.Offline.Disabled { return errors.New("saving to offline db disabled") @@ -38,7 +40,7 @@ func SaveHeartbeats(v *viper.Viper, heartbeats []heartbeat.Heartbeat, queueFilep if heartbeats == nil { // We're not saving surplus extra heartbeats, so save // main heartbeat and all extra heartbeats to offline db - heartbeats = buildHeartbeats(params) + heartbeats = buildHeartbeats(ctx, params) } handleOpts := initHandleOptions(params) @@ -48,18 +50,20 @@ func SaveHeartbeats(v *viper.Viper, heartbeats []heartbeat.Heartbeat, queueFilep sender := offline.Noop{} handle := heartbeat.NewHandle(sender, handleOpts...) - _, _ = handle(heartbeats) + _, _ = handle(ctx, heartbeats) return nil } -func loadParams(v *viper.Viper) (paramscmd.Params, error) { - paramAPI, err := paramscmd.LoadAPIParams(v) +func loadParams(ctx context.Context, v *viper.Viper) (paramscmd.Params, error) { + logger := log.Extract(ctx) + + paramAPI, err := paramscmd.LoadAPIParams(ctx, v) if err != nil { - log.Warnf("failed to load API parameters: %s", err) + logger.Warnf("failed to load API parameters: %s", err) } - paramHeartbeat, err := paramscmd.LoadHeartbeatParams(v) + paramHeartbeat, err := paramscmd.LoadHeartbeatParams(ctx, v) if err != nil { return paramscmd.Params{}, fmt.Errorf("failed to load heartbeat parameters: %s", err) } @@ -67,14 +71,14 @@ func loadParams(v *viper.Viper) (paramscmd.Params, error) { return paramscmd.Params{ API: paramAPI, Heartbeat: paramHeartbeat, - Offline: paramscmd.LoadOfflineParams(v), + Offline: paramscmd.LoadOfflineParams(ctx, v), }, nil } -func buildHeartbeats(params paramscmd.Params) []heartbeat.Heartbeat { +func buildHeartbeats(ctx context.Context, params paramscmd.Params) []heartbeat.Heartbeat { heartbeats := []heartbeat.Heartbeat{} - userAgent := heartbeat.UserAgent(params.API.Plugin) + userAgent := heartbeat.UserAgent(ctx, params.API.Plugin) heartbeats = append(heartbeats, heartbeat.New( params.Heartbeat.Project.BranchAlternate, @@ -100,7 +104,8 @@ func buildHeartbeats(params paramscmd.Params) []heartbeat.Heartbeat { )) if len(params.Heartbeat.ExtraHeartbeats) > 0 { - log.Debugf("include %d extra heartbeat(s) from stdin", len(params.Heartbeat.ExtraHeartbeats)) + logger := log.Extract(ctx) + logger.Debugf("include %d extra heartbeat(s) from stdin", len(params.Heartbeat.ExtraHeartbeats)) for _, h := range params.Heartbeat.ExtraHeartbeats { heartbeats = append(heartbeats, heartbeat.New( @@ -171,20 +176,19 @@ func initHandleOptions(params paramscmd.Params) []heartbeat.HandleOption { } } -func setLogFields(params paramscmd.Params) { +func setLogFields(ctx context.Context, params paramscmd.Params) { + log.AddField(ctx, "file", params.Heartbeat.Entity) + log.AddField(ctx, "time", params.Heartbeat.Time) + if params.API.Plugin != "" { - log.WithField("plugin", params.API.Plugin) + log.AddField(ctx, "plugin", params.API.Plugin) } - log.WithField("time", params.Heartbeat.Time) - if params.Heartbeat.LineNumber != nil { - log.WithField("lineno", params.Heartbeat.LineNumber) + log.AddField(ctx, "lineno", params.Heartbeat.LineNumber) } if params.Heartbeat.IsWrite != nil { - log.WithField("is_write", params.Heartbeat.IsWrite) + log.AddField(ctx, "is_write", params.Heartbeat.IsWrite) } - - log.WithField("file", params.Heartbeat.Entity) } diff --git a/cmd/offlinecount/offlinecount.go b/cmd/offlinecount/offlinecount.go index 75f67f26..133e246f 100644 --- a/cmd/offlinecount/offlinecount.go +++ b/cmd/offlinecount/offlinecount.go @@ -1,6 +1,7 @@ package offlinecount import ( + "context" "fmt" "github.com/wakatime/wakatime-cli/pkg/exitcode" @@ -10,8 +11,8 @@ import ( ) // Run executes the offline-count command. -func Run(v *viper.Viper) (int, error) { - queueFilepath, err := offline.QueueFilepath(v) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + queueFilepath, err := offline.QueueFilepath(ctx, v) if err != nil { return exitcode.ErrGeneric, fmt.Errorf( "failed to load offline queue filepath: %s", @@ -19,7 +20,7 @@ func Run(v *viper.Viper) (int, error) { ) } - count, err := offline.CountHeartbeats(queueFilepath) + count, err := offline.CountHeartbeats(ctx, queueFilepath) if err != nil { fmt.Println(err) return exitcode.ErrGeneric, fmt.Errorf("failed to count offline heartbeats: %w", err) diff --git a/cmd/offlinecount/offlinecount_test.go b/cmd/offlinecount/offlinecount_test.go index c6ce8688..7bc3d033 100644 --- a/cmd/offlinecount/offlinecount_test.go +++ b/cmd/offlinecount/offlinecount_test.go @@ -2,6 +2,7 @@ package offlinecount_test import ( "bytes" + "context" "fmt" "io" "os" @@ -41,7 +42,7 @@ func TestOfflineCount_Empty(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - code, err := offlinecount.Run(v) + code, err := offlinecount.Run(context.Background(), v) assert.Equal(t, exitcode.Success, code) require.NoError(t, err) @@ -103,7 +104,7 @@ func TestOfflineCount(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - code, err := offlinecount.Run(v) + code, err := offlinecount.Run(context.Background(), v) outC := make(chan string) // copy the output in a separate goroutine so printing can't block indefinitely diff --git a/cmd/offlineprint/offlineprint.go b/cmd/offlineprint/offlineprint.go index 300c338d..861ec0f9 100644 --- a/cmd/offlineprint/offlineprint.go +++ b/cmd/offlineprint/offlineprint.go @@ -2,6 +2,7 @@ package offlineprint import ( "bytes" + "context" "encoding/json" "fmt" @@ -14,8 +15,8 @@ import ( ) // Run executes the print-offline-heartbeats command. -func Run(v *viper.Viper) (int, error) { - queueFilepath, err := offline.QueueFilepath(v) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + queueFilepath, err := offline.QueueFilepath(ctx, v) if err != nil { return exitcode.ErrGeneric, fmt.Errorf( "failed to load offline queue filepath: %s", @@ -23,9 +24,9 @@ func Run(v *viper.Viper) (int, error) { ) } - p := params.LoadOfflineParams(v) + p := params.LoadOfflineParams(ctx, v) - hh, err := offline.ReadHeartbeats(queueFilepath, p.PrintMax) + hh, err := offline.ReadHeartbeats(ctx, queueFilepath, p.PrintMax) if err != nil { fmt.Println(err) return exitcode.ErrGeneric, fmt.Errorf("failed to read offline heartbeats: %w", err) diff --git a/cmd/offlineprint/offlineprint_test.go b/cmd/offlineprint/offlineprint_test.go index a964ba75..e7feab9e 100644 --- a/cmd/offlineprint/offlineprint_test.go +++ b/cmd/offlineprint/offlineprint_test.go @@ -2,6 +2,7 @@ package offlineprint_test import ( "bytes" + "context" "fmt" "io" "os" @@ -55,7 +56,7 @@ func TestPrintOfflineHeartbeats(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - code, err := offlineprint.Run(v) + code, err := offlineprint.Run(context.Background(), v) require.NoError(t, err) outC := make(chan string) @@ -95,7 +96,7 @@ func TestPrintOfflineHeartbeats_Empty(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - code, err := offlineprint.Run(v) + code, err := offlineprint.Run(context.Background(), v) require.NoError(t, err) outC := make(chan string) diff --git a/cmd/offlinesync/offlinesync.go b/cmd/offlinesync/offlinesync.go index 35eb5e31..3f799c33 100644 --- a/cmd/offlinesync/offlinesync.go +++ b/cmd/offlinesync/offlinesync.go @@ -1,6 +1,7 @@ package offlinesync import ( + "context" "fmt" "os" @@ -18,33 +19,35 @@ import ( ) // RunWithoutRateLimiting executes the sync-offline-activity command without rate limiting. -func RunWithoutRateLimiting(v *viper.Viper) (int, error) { - return run(v) +func RunWithoutRateLimiting(ctx context.Context, v *viper.Viper) (int, error) { + return run(ctx, v) } // RunWithRateLimiting executes sync-offline-activity command with rate limiting enabled. -func RunWithRateLimiting(v *viper.Viper) (int, error) { - paramOffline := params.LoadOfflineParams(v) +func RunWithRateLimiting(ctx context.Context, v *viper.Viper) (int, error) { + paramOffline := params.LoadOfflineParams(ctx, v) + + logger := log.Extract(ctx) if cmdheartbeat.RateLimited(cmdheartbeat.RateLimitParams{ Disabled: paramOffline.Disabled, LastSentAt: paramOffline.LastSentAt, Timeout: paramOffline.RateLimit, }) { - log.Debugln("skip syncing offline activity to respect rate limit") + logger.Debugln("skip syncing offline activity to respect rate limit") return exitcode.Success, nil } - return run(v) + return run(ctx, v) } -func run(v *viper.Viper) (int, error) { - paramOffline := params.LoadOfflineParams(v) +func run(ctx context.Context, v *viper.Viper) (int, error) { + paramOffline := params.LoadOfflineParams(ctx, v) if paramOffline.Disabled { return exitcode.Success, nil } - queueFilepath, err := offline.QueueFilepath(v) + queueFilepath, err := offline.QueueFilepath(ctx, v) if err != nil { return exitcode.ErrGeneric, fmt.Errorf( "offline sync failed: failed to load offline queue filepath: %s", @@ -52,16 +55,18 @@ func run(v *viper.Viper) (int, error) { ) } - queueFilepathLegacy, err := offline.QueueFilepathLegacy(v) + logger := log.Extract(ctx) + + queueFilepathLegacy, err := offline.QueueFilepathLegacy(ctx, v) if err != nil { - log.Warnf("legacy offline sync failed: failed to load offline queue filepath: %s", err) + logger.Warnf("legacy offline sync failed: failed to load offline queue filepath: %s", err) } - if err = syncOfflineActivityLegacy(v, queueFilepathLegacy); err != nil { - log.Warnf("legacy offline sync failed: %s", err) + if err = syncOfflineActivityLegacy(ctx, v, queueFilepathLegacy); err != nil { + logger.Warnf("legacy offline sync failed: %s", err) } - if err = SyncOfflineActivity(v, queueFilepath); err != nil { + if err = SyncOfflineActivity(ctx, v, queueFilepath); err != nil { if errwaka, ok := err.(wakaerror.Error); ok { return errwaka.ExitCode(), fmt.Errorf("offline sync failed: %s", errwaka.Message()) } @@ -72,14 +77,14 @@ func run(v *viper.Viper) (int, error) { ) } - log.Debugln("successfully synced offline activity") + logger.Debugln("successfully synced offline activity") return exitcode.Success, nil } // syncOfflineActivityLegacy syncs the old offline activity by sending heartbeats // from the legacy offline queue to the WakaTime API. -func syncOfflineActivityLegacy(v *viper.Viper, queueFilepath string) error { +func syncOfflineActivityLegacy(ctx context.Context, v *viper.Viper, queueFilepath string) error { if queueFilepath == "" { return nil } @@ -88,14 +93,14 @@ func syncOfflineActivityLegacy(v *viper.Viper, queueFilepath string) error { return nil } - paramOffline := params.LoadOfflineParams(v) + paramOffline := params.LoadOfflineParams(ctx, v) - paramAPI, err := params.LoadAPIParams(v) + paramAPI, err := params.LoadAPIParams(ctx, v) if err != nil { return fmt.Errorf("failed to load API parameters: %w", err) } - apiClient, err := cmdapi.NewClientWithoutAuth(paramAPI) + apiClient, err := cmdapi.NewClientWithoutAuth(ctx, paramAPI) if err != nil { return fmt.Errorf("failed to initialize api client: %w", err) } @@ -108,13 +113,15 @@ func syncOfflineActivityLegacy(v *viper.Viper, queueFilepath string) error { }), ) - _, err = handle(nil) + _, err = handle(ctx, nil) if err != nil { return err } + logger := log.Extract(ctx) + if err := os.Remove(queueFilepath); err != nil { - log.Warnf("failed to delete legacy offline file: %s", err) + logger.Warnf("failed to delete legacy offline file: %s", err) } return nil @@ -122,18 +129,18 @@ func syncOfflineActivityLegacy(v *viper.Viper, queueFilepath string) error { // SyncOfflineActivity syncs offline activity by sending heartbeats // from the offline queue to the WakaTime API. -func SyncOfflineActivity(v *viper.Viper, queueFilepath string) error { - paramAPI, err := params.LoadAPIParams(v) +func SyncOfflineActivity(ctx context.Context, v *viper.Viper, queueFilepath string) error { + paramAPI, err := params.LoadAPIParams(ctx, v) if err != nil { return fmt.Errorf("failed to load API parameters: %w", err) } - apiClient, err := cmdapi.NewClientWithoutAuth(paramAPI) + apiClient, err := cmdapi.NewClientWithoutAuth(ctx, paramAPI) if err != nil { return fmt.Errorf("failed to initialize api client: %w", err) } - paramOffline := params.LoadOfflineParams(v) + paramOffline := params.LoadOfflineParams(ctx, v) handle := heartbeat.NewHandle(apiClient, offline.WithSync(queueFilepath, paramOffline.SyncMax), @@ -143,13 +150,15 @@ func SyncOfflineActivity(v *viper.Viper, queueFilepath string) error { }), ) - _, err = handle(nil) + _, err = handle(ctx, nil) if err != nil { return err } - if err := cmdheartbeat.ResetRateLimit(v); err != nil { - log.Errorf("failed to reset rate limit: %s", err) + logger := log.Extract(ctx) + + if err := cmdheartbeat.ResetRateLimit(ctx, v); err != nil { + logger.Errorf("failed to reset rate limit: %s", err) } return nil diff --git a/cmd/offlinesync/offlinesync_internal_test.go b/cmd/offlinesync/offlinesync_internal_test.go index a8412e6a..477e4dc2 100644 --- a/cmd/offlinesync/offlinesync_internal_test.go +++ b/cmd/offlinesync/offlinesync_internal_test.go @@ -1,6 +1,7 @@ package offlinesync import ( + "context" "fmt" "io" "net/http" @@ -101,7 +102,7 @@ func TestSyncOfflineActivityLegacy(t *testing.T) { v.Set("sync-offline-activity", 100) v.Set("plugin", plugin) - err = syncOfflineActivityLegacy(v, f.Name()) + err = syncOfflineActivityLegacy(context.Background(), v, f.Name()) require.NoError(t, err) assert.NoFileExists(t, f.Name()) diff --git a/cmd/offlinesync/offlinesync_test.go b/cmd/offlinesync/offlinesync_test.go index 51f2f8f7..8fabb704 100644 --- a/cmd/offlinesync/offlinesync_test.go +++ b/cmd/offlinesync/offlinesync_test.go @@ -1,6 +1,7 @@ package offlinesync_test import ( + "context" "encoding/json" "fmt" "io" @@ -108,7 +109,7 @@ func TestRunWithRateLimiting(t *testing.T) { v.Set("sync-offline-activity", 100) v.Set("plugin", plugin) - code, err := offlinesync.RunWithRateLimiting(v) + code, err := offlinesync.RunWithRateLimiting(context.Background(), v) require.NoError(t, err) assert.Equal(t, exitcode.Success, code) @@ -200,7 +201,7 @@ func TestRunWithoutRateLimiting(t *testing.T) { v.Set("sync-offline-activity", 100) v.Set("plugin", plugin) - code, err := offlinesync.RunWithoutRateLimiting(v) + code, err := offlinesync.RunWithoutRateLimiting(context.Background(), v) require.NoError(t, err) assert.Equal(t, exitcode.Success, code) @@ -215,7 +216,7 @@ func TestRunWithRateLimiting_RateLimited(t *testing.T) { v.Set("heartbeat-rate-limit-seconds", 500) v.Set("internal.heartbeats_last_sent_at", time.Now().Add(-time.Minute).Format(time.RFC3339)) - code, err := offlinesync.RunWithRateLimiting(v) + code, err := offlinesync.RunWithRateLimiting(context.Background(), v) require.NoError(t, err) assert.Equal(t, exitcode.Success, code) @@ -305,7 +306,7 @@ func TestSyncOfflineActivity(t *testing.T) { v.Set("sync-offline-activity", 100) v.Set("plugin", plugin) - err = offlinesync.SyncOfflineActivity(v, f.Name()) + err = offlinesync.SyncOfflineActivity(context.Background(), v, f.Name()) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) @@ -396,7 +397,7 @@ func TestSyncOfflineActivity_MultipleApiKey(t *testing.T) { v.Set("sync-offline-activity", 100) v.Set("plugin", plugin) - err = offlinesync.SyncOfflineActivity(v, f.Name()) + err = offlinesync.SyncOfflineActivity(context.Background(), v, f.Name()) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) diff --git a/cmd/params/params.go b/cmd/params/params.go index e6627b97..457dd01f 100644 --- a/cmd/params/params.go +++ b/cmd/params/params.go @@ -167,12 +167,14 @@ type ( // LoadAPIParams loads API params from viper.Viper instance. Returns ErrAuth // if failed to retrieve api key. -func LoadAPIParams(v *viper.Viper) (API, error) { - apiKey, err := LoadAPIKey(v) +func LoadAPIParams(ctx context.Context, v *viper.Viper) (API, error) { + apiKey, err := LoadAPIKey(ctx, v) if err != nil { return API{}, err } + logger := log.Extract(ctx) + var apiKeyPatterns []apikey.MapPattern apiKeyMap := vipertools.GetStringMapString(v, "project_api_key") @@ -185,7 +187,7 @@ func LoadAPIParams(v *viper.Viper) (API, error) { compiled, err := regex.Compile(k) if err != nil { - log.Warnf("failed to compile project_api_key regex pattern %q", k) + logger.Warnf("failed to compile project_api_key regex pattern %q", k) continue } @@ -228,7 +230,7 @@ func LoadAPIParams(v *viper.Viper) (API, error) { parsed, err := safeTimeParse(ini.DateFormat, backoffAtStr) // nolint:gocritic if err != nil { - log.Warnf("failed to parse backoff_at: %s", err) + logger.Warnf("failed to parse backoff_at: %s", err) } else if parsed.After(time.Now()) { backoffAt = time.Now() } else { @@ -242,7 +244,7 @@ func LoadAPIParams(v *viper.Viper) (API, error) { if backoffRetriesStr != "" { parsed, err := strconv.Atoi(backoffRetriesStr) if err != nil { - log.Warnf("failed to parse backoff_retries: %s", err) + logger.Warnf("failed to parse backoff_retries: %s", err) } else { backoffRetries = parsed } @@ -258,7 +260,7 @@ func LoadAPIParams(v *viper.Viper) (API, error) { if hostname == "" { hostname, err = os.Hostname() if err != nil { - log.Warnf("failed to retrieve hostname from system: %s", err) + logger.Warnf("failed to retrieve hostname from system: %s", err) } } @@ -277,7 +279,7 @@ func LoadAPIParams(v *viper.Viper) (API, error) { proxyEnvURL, err := proxyEnv.ProxyFunc()(apiURL) if err != nil { - log.Warnf("failed to get proxy url from environment for api url: %s", err) + logger.Warnf("failed to get proxy url from environment for api url: %s", err) } // try use proxy from environment if no custom proxy is set @@ -315,7 +317,7 @@ func LoadAPIParams(v *viper.Viper) (API, error) { } // LoadAPIKey loads a valid default WakaTime API Key or returns an error. -func LoadAPIKey(v *viper.Viper) (string, error) { +func LoadAPIKey(ctx context.Context, v *viper.Viper) (string, error) { apiKey := vipertools.FirstNonEmptyString(v, "key", "settings.api_key", "settings.apikey") if apiKey != "" { if !apiKeyRegex.MatchString(apiKey) { @@ -330,12 +332,14 @@ func LoadAPIKey(v *viper.Viper) (string, error) { return "", api.ErrAuth{Err: fmt.Errorf("failed to read api key from vault: %s", err)} } + logger := log.Extract(ctx) + if apiKey != "" { if !apiKeyRegex.MatchString(apiKey) { return "", api.ErrAuth{Err: errors.New("invalid api key format")} } - log.Debugln("loaded api key from vault") + logger.Debugln("loaded api key from vault") return apiKey, nil } @@ -346,7 +350,7 @@ func LoadAPIKey(v *viper.Viper) (string, error) { return "", api.ErrAuth{Err: errors.New("invalid api key format")} } - log.Debugln("loaded api key from env var") + logger.Debugln("loaded api key from env var") return apiKey, nil } @@ -359,7 +363,7 @@ func LoadAPIKey(v *viper.Viper) (string, error) { } // LoadHeartbeatParams loads heartbeats params from viper.Viper instance. -func LoadHeartbeatParams(v *viper.Viper) (Heartbeat, error) { +func LoadHeartbeatParams(ctx context.Context, v *viper.Viper) (Heartbeat, error) { var category heartbeat.Category if categoryStr := vipertools.GetString(v, "category"); categoryStr != "" { @@ -400,7 +404,7 @@ func LoadHeartbeatParams(v *viper.Viper) (Heartbeat, error) { var extraHeartbeats []heartbeat.Heartbeat if v.GetBool("extra-heartbeats") { - extraHeartbeats = readExtraHeartbeats() + extraHeartbeats = readExtraHeartbeats(ctx) } var isWrite *bool @@ -433,7 +437,7 @@ func LoadHeartbeatParams(v *viper.Viper) (Heartbeat, error) { timeSecs = float64(time.Now().UnixNano()) / 1000000000 } - projectParams, err := loadProjectParams(v) + projectParams, err := loadProjectParams(ctx, v) if err != nil { return Heartbeat{}, fmt.Errorf("failed to parse project params: %s", err) } @@ -465,17 +469,19 @@ func LoadHeartbeatParams(v *viper.Viper) (Heartbeat, error) { LinesInFile: linesInFile, LocalFile: vipertools.GetString(v, "local-file"), Time: timeSecs, - Filter: loadFilterParams(v), + Filter: loadFilterParams(ctx, v), Project: projectParams, Sanitize: sanitizeParams, }, nil } -func loadFilterParams(v *viper.Viper) FilterParams { +func loadFilterParams(ctx context.Context, v *viper.Viper) FilterParams { exclude := v.GetStringSlice("exclude") exclude = append(exclude, v.GetStringSlice("settings.exclude")...) exclude = append(exclude, v.GetStringSlice("settings.ignore")...) + logger := log.Extract(ctx) + var excludePatterns []regex.Regex for _, s := range exclude { @@ -486,7 +492,7 @@ func loadFilterParams(v *viper.Viper) FilterParams { compiled, err := regex.Compile(s) if err != nil { - log.Warnf("failed to compile exclude regex pattern %q", s) + logger.Warnf("failed to compile exclude regex pattern %q", s) continue } @@ -506,7 +512,7 @@ func loadFilterParams(v *viper.Viper) FilterParams { compiled, err := regex.Compile(s) if err != nil { - log.Warnf("failed to compile include regex pattern %q", s) + logger.Warnf("failed to compile include regex pattern %q", s) continue } @@ -595,7 +601,7 @@ func loadSanitizeParams(v *viper.Viper) (SanitizeParams, error) { }, nil } -func loadProjectParams(v *viper.Viper) (ProjectParams, error) { +func loadProjectParams(ctx context.Context, v *viper.Viper) (ProjectParams, error) { submodulesDisabled, err := parseBoolOrRegexList(vipertools.GetString(v, "git.submodules_disabled")) if err != nil { return ProjectParams{}, fmt.Errorf( @@ -607,15 +613,17 @@ func loadProjectParams(v *viper.Viper) (ProjectParams, error) { return ProjectParams{ Alternate: vipertools.GetString(v, "alternate-project"), BranchAlternate: vipertools.GetString(v, "alternate-branch"), - MapPatterns: loadProjectMapPatterns(v, "projectmap"), + MapPatterns: loadProjectMapPatterns(ctx, v, "projectmap"), Override: vipertools.GetString(v, "project"), ProjectFromGitRemote: v.GetBool("git.project_from_git_remote"), SubmodulesDisabled: submodulesDisabled, - SubmoduleMapPatterns: loadProjectMapPatterns(v, "git_submodule_projectmap"), + SubmoduleMapPatterns: loadProjectMapPatterns(ctx, v, "git_submodule_projectmap"), }, nil } -func loadProjectMapPatterns(v *viper.Viper, prefix string) []project.MapPattern { +func loadProjectMapPatterns(ctx context.Context, v *viper.Viper, prefix string) []project.MapPattern { + logger := log.Extract(ctx) + var mapPatterns []project.MapPattern values := vipertools.GetStringMapString(v, prefix) @@ -628,7 +636,7 @@ func loadProjectMapPatterns(v *viper.Viper, prefix string) []project.MapPattern compiled, err := regex.Compile(k) if err != nil { - log.Warnf("failed to compile projectmap regex pattern %q", k) + logger.Warnf("failed to compile projectmap regex pattern %q", k) continue } @@ -642,22 +650,24 @@ func loadProjectMapPatterns(v *viper.Viper, prefix string) []project.MapPattern } // LoadOfflineParams loads offline params from viper.Viper instance. -func LoadOfflineParams(v *viper.Viper) Offline { +func LoadOfflineParams(ctx context.Context, v *viper.Viper) Offline { disabled := vipertools.FirstNonEmptyBool(v, "disable-offline", "disableoffline") if b := v.GetBool("settings.offline"); v.IsSet("settings.offline") { disabled = !b } + logger := log.Extract(ctx) + rateLimit, _ := vipertools.FirstNonEmptyInt(v, "heartbeat-rate-limit-seconds", "settings.heartbeat_rate_limit_seconds") if rateLimit < 0 { - log.Warnf("argument --heartbeat-rate-limit-seconds must be zero or a positive integer number, got %d", rateLimit) + logger.Warnf("argument --heartbeat-rate-limit-seconds must be zero or a positive integer number, got %d", rateLimit) rateLimit = 0 } syncMax := v.GetInt("sync-offline-activity") if syncMax < 0 { - log.Warnf("argument --sync-offline-activity must be zero or a positive integer number, got %d", syncMax) + logger.Warnf("argument --sync-offline-activity must be zero or a positive integer number, got %d", syncMax) syncMax = 0 } @@ -669,7 +679,7 @@ func LoadOfflineParams(v *viper.Viper) Offline { parsed, err := safeTimeParse(ini.DateFormat, lastSentAtStr) // nolint:gocritic if err != nil { - log.Warnf("failed to parse heartbeats_last_sent_at: %s", err) + logger.Warnf("failed to parse heartbeats_last_sent_at: %s", err) } else if parsed.After(time.Now()) { lastSentAt = time.Now() } else { @@ -767,18 +777,20 @@ var extraHeartbeatsCache []heartbeat.Heartbeat // nolint:gochecknoglobals // Once prevents reading from stdin twice. var Once sync.Once // nolint:gochecknoglobals -func readExtraHeartbeats() []heartbeat.Heartbeat { +func readExtraHeartbeats(ctx context.Context) []heartbeat.Heartbeat { Once.Do(func() { + logger := log.Extract(ctx) + in := bufio.NewReader(os.Stdin) input, err := in.ReadString('\n') if err != nil && err != io.EOF { - log.Debugf("failed to read data from stdin: %s", err) + logger.Debugf("failed to read data from stdin: %s", err) } - heartbeats, err := parseExtraHeartbeats(input) + heartbeats, err := parseExtraHeartbeats(ctx, input) if err != nil { - log.Errorf("failed parsing: %s", err) + logger.Errorf("failed parsing: %s", err) } extraHeartbeatsCache = heartbeats @@ -787,9 +799,11 @@ func readExtraHeartbeats() []heartbeat.Heartbeat { return extraHeartbeatsCache } -func parseExtraHeartbeats(data string) ([]heartbeat.Heartbeat, error) { +func parseExtraHeartbeats(ctx context.Context, data string) ([]heartbeat.Heartbeat, error) { + logger := log.Extract(ctx) + if data == "" { - log.Debugln("skipping extra heartbeats, as no data was provided") + logger.Debugln("skipping extra heartbeats, as no data was provided") return nil, nil } @@ -1149,9 +1163,9 @@ func parseBoolOrRegexList(s string) ([]regex.Regex, error) { switch { case s == "": case strings.ToLower(s) == "false": - patterns = []regex.Regex{matchNoneRegex} + patterns = []regex.Regex{regex.NewRegexpWrap(matchNoneRegex)} case strings.ToLower(s) == "true": - patterns = []regex.Regex{matchAllRegex} + patterns = []regex.Regex{regex.NewRegexpWrap(matchAllRegex)} default: splitted := strings.Split(s, "\n") for _, s := range splitted { diff --git a/cmd/params/params_internal_test.go b/cmd/params/params_internal_test.go index 25b7866f..2d72e323 100644 --- a/cmd/params/params_internal_test.go +++ b/cmd/params/params_internal_test.go @@ -23,31 +23,31 @@ func TestParseBoolOrRegexList(t *testing.T) { }, "false string": { Input: "false", - Expected: []regex.Regex{regexp.MustCompile("a^")}, + Expected: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("a^"))}, }, "true string": { Input: "true", - Expected: []regex.Regex{regexp.MustCompile(".*")}, + Expected: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, "valid regex": { Input: "\t.?\n\t\n \n\t\twakatime.? \t\n", Expected: []regex.Regex{ - regexp.MustCompile(".?"), - regexp.MustCompile("wakatime.?"), + regex.NewRegexpWrap(regexp.MustCompile(".?")), + regex.NewRegexpWrap(regexp.MustCompile("wakatime.?")), }, }, "valid regex with windows style": { Input: "\t.?\r\n\t\t\twakatime.? \t\r\n", Expected: []regex.Regex{ - regexp.MustCompile(".?"), - regexp.MustCompile("wakatime.?"), + regex.NewRegexpWrap(regexp.MustCompile(".?")), + regex.NewRegexpWrap(regexp.MustCompile("wakatime.?")), }, }, "valid regex with old mac style": { Input: "\t.?\r\t\t\twakatime.? \t\r", Expected: []regex.Regex{ - regexp.MustCompile(".?"), - regexp.MustCompile("wakatime.?"), + regex.NewRegexpWrap(regexp.MustCompile(".?")), + regex.NewRegexpWrap(regexp.MustCompile("wakatime.?")), }, }, } diff --git a/cmd/params/params_test.go b/cmd/params/params_test.go index 9d29ce74..484b690b 100644 --- a/cmd/params/params_test.go +++ b/cmd/params/params_test.go @@ -2,6 +2,7 @@ package params_test import ( "bytes" + "context" "fmt" "io" "os" @@ -13,6 +14,7 @@ import ( "testing" "time" + "github.com/wakatime/wakatime-cli/cmd" cmdparams "github.com/wakatime/wakatime-cli/cmd/params" "github.com/wakatime/wakatime-cli/pkg/api" "github.com/wakatime/wakatime-cli/pkg/apikey" @@ -22,11 +24,11 @@ import ( "github.com/wakatime/wakatime-cli/pkg/output" "github.com/wakatime/wakatime-cli/pkg/project" "github.com/wakatime/wakatime-cli/pkg/regex" + "gopkg.in/ini.v1" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/ini.v1" ) func TestLoadHeartbeatParams_AlternateProject(t *testing.T) { @@ -34,7 +36,7 @@ func TestLoadHeartbeatParams_AlternateProject(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("alternate-project", "web") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "web", params.Project.Alternate) @@ -44,13 +46,15 @@ func TestLoadHeartbeatParams_AlternateProject_Unset(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Empty(t, params.Project.Alternate) } func TestLoadHeartbeatParams_Category(t *testing.T) { + ctx := context.Background() + tests := map[string]heartbeat.Category{ "advising": heartbeat.AdvisingCategory, "browsing": heartbeat.BrowsingCategory, @@ -78,7 +82,7 @@ func TestLoadHeartbeatParams_Category(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("category", name) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, category, params.Category) @@ -90,7 +94,7 @@ func TestLoadHeartbeatParams_Category_Default(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, heartbeat.CodingCategory, params.Category) @@ -101,7 +105,7 @@ func TestLoadHeartbeatParams_Category_Invalid(t *testing.T) { v.SetDefault("sync-offline-activity", 1000) v.Set("category", "invalid") - _, err := cmdparams.LoadHeartbeatParams(v) + _, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.Error(t, err) assert.Equal(t, "failed to parse category: invalid category \"invalid\"", err.Error()) @@ -112,7 +116,7 @@ func TestLoadHeartbeatParams_CursorPosition(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("cursorpos", 42) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, 42, *params.CursorPosition) @@ -123,7 +127,7 @@ func TestLoadHeartbeatParams_CursorPosition_Zero(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("cursorpos", 0) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Zero(t, *params.CursorPosition) @@ -134,7 +138,7 @@ func TestLoadHeartbeatParams_CursorPosition_Unset(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Nil(t, params.CursorPosition) @@ -145,7 +149,7 @@ func TestLoadHeartbeatParams_Entity_EntityFlagTakesPrecedence(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("file", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "/path/to/file", params.Entity) @@ -158,7 +162,7 @@ func TestLoadHeartbeatParams_Entity_FileFlag(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, filepath.Join(home, "/path/to/file"), params.Entity) @@ -167,13 +171,15 @@ func TestLoadHeartbeatParams_Entity_FileFlag(t *testing.T) { func TestLoadHeartbeatParams_Entity_Unset(t *testing.T) { v := viper.New() - _, err := cmdparams.LoadHeartbeatParams(v) + _, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.Error(t, err) assert.Equal(t, "failed to retrieve entity", err.Error()) } func TestLoadHeartbeatParams_EntityType(t *testing.T) { + ctx := context.Background() + tests := map[string]heartbeat.EntityType{ "file": heartbeat.FileType, "domain": heartbeat.DomainType, @@ -186,7 +192,7 @@ func TestLoadHeartbeatParams_EntityType(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("entity-type", name) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, entityType, params.EntityType) @@ -198,7 +204,7 @@ func TestLoadHeartbeatParams_EntityType_Default(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, heartbeat.FileType, params.EntityType) @@ -209,7 +215,7 @@ func TestLoadHeartbeatParams_EntityType_Invalid(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("entity-type", "invalid") - _, err := cmdparams.LoadHeartbeatParams(v) + _, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.Error(t, err) assert.Equal( @@ -249,7 +255,7 @@ func TestLoadHeartbeatParams_ExtraHeartbeats(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("extra-heartbeats", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Len(t, params.ExtraHeartbeats, 2) @@ -323,7 +329,7 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_WithStringValues(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("extra-heartbeats", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Len(t, params.ExtraHeartbeats, 2) @@ -393,7 +399,7 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_WithEOF(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("extra-heartbeats", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Len(t, params.ExtraHeartbeats, 2) @@ -445,10 +451,7 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_NoData(t *testing.T) { w.Close() }() - logs := bytes.NewBuffer(nil) - - teardownLogCapture := captureLogs(logs) - defer teardownLogCapture() + ctx := context.Background() origStdin := os.Stdin @@ -465,17 +468,34 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_NoData(t *testing.T) { w.Close() }() + logFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer logFile.Close() + v := viper.New() v.Set("entity", "/path/to/file") v.Set("extra-heartbeats", true) + v.Set("log-file", logFile.Name()) + v.Set("verbose", true) + + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Empty(t, params.ExtraHeartbeats) - assert.Contains(t, logs.String(), "skipping extra heartbeats, as no data was provided") - assert.NotContains(t, logs.String(), "failed to read extra heartbeats: failed parsing") + output, err := io.ReadAll(logFile) + require.NoError(t, err) + + assert.Contains(t, string(output), "skipping extra heartbeats, as no data was provided") + assert.NotContains(t, string(output), "failed to read extra heartbeats: failed parsing") } func TestLoadHeartbeat_GuessLanguage_FlagTakesPrecedence(t *testing.T) { @@ -484,7 +504,7 @@ func TestLoadHeartbeat_GuessLanguage_FlagTakesPrecedence(t *testing.T) { v.Set("guess-language", true) v.Set("settings.guess_language", false) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.GuessLanguage) @@ -495,7 +515,7 @@ func TestLoadHeartbeat_GuessLanguage_FromConfig(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("settings.guess_language", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.GuessLanguage) @@ -505,7 +525,7 @@ func TestLoadHeartbeat_GuessLanguage_Default(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.False(t, params.GuessLanguage) @@ -516,13 +536,15 @@ func TestLoadHeartbeatParams_IsUnsavedEntity(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("is-unsaved-entity", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.IsUnsavedEntity) } func TestLoadHeartbeatParams_IsWrite(t *testing.T) { + ctx := context.Background() + tests := map[string]bool{ "is write": true, "is no write": false, @@ -534,7 +556,7 @@ func TestLoadHeartbeatParams_IsWrite(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("write", isWrite) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, isWrite, *params.IsWrite) @@ -546,7 +568,7 @@ func TestLoadHeartbeatParams_IsWrite_Unset(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Nil(t, params.IsWrite) @@ -557,7 +579,7 @@ func TestLoadHeartbeatParams_Language(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("language", "Go") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageGo.String(), *params.Language) @@ -568,7 +590,7 @@ func TestLoadHeartbeatParams_LanguageAlternate(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("alternate-language", "Go") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageGo.String(), params.LanguageAlternate) @@ -580,7 +602,7 @@ func TestLoadHeartbeatParams_LineNumber(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("lineno", 42) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, 42, *params.LineNumber) @@ -591,7 +613,7 @@ func TestLoadHeartbeatParams_LineNumber_Zero(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("lineno", 0) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Zero(t, *params.LineNumber) @@ -601,7 +623,7 @@ func TestLoadHeartbeatParams_LineNumber_Unset(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Nil(t, params.LineNumber) @@ -612,7 +634,7 @@ func TestLoadHeartbeatParams_LocalFile(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("local-file", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "/path/to/file", params.LocalFile) @@ -623,7 +645,7 @@ func TestLoadHeartbeatParams_Project(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("project", "billing") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "billing", params.Project.Override) @@ -633,13 +655,15 @@ func TestLoadHeartbeatParams_Project_Unset(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Empty(t, params.Project.Override) } func TestLoadHeartbeatParams_ProjectMap(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Entity string Regex regex.Regex @@ -648,23 +672,23 @@ func TestLoadHeartbeatParams_ProjectMap(t *testing.T) { }{ "simple regex": { Entity: "/home/user/projects/foo/file", - Regex: regexp.MustCompile("projects/foo"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("projects/foo")), Project: "My Awesome Project", Expected: []project.MapPattern{ { Name: "My Awesome Project", - Regex: regexp.MustCompile("(?i)projects/foo"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("(?i)projects/foo")), }, }, }, "regex with group replacement": { Entity: "/home/user/projects/bar123/file", - Regex: regexp.MustCompile(`^/home/user/projects/bar(\\d+)/`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`^/home/user/projects/bar(\\d+)/`)), Project: "project{0}", Expected: []project.MapPattern{ { Name: "project{0}", - Regex: regexp.MustCompile(`(?i)^/home/user/projects/bar(\\d+)/`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`(?i)^/home/user/projects/bar(\\d+)/`)), }, }, }, @@ -676,7 +700,7 @@ func TestLoadHeartbeatParams_ProjectMap(t *testing.T) { v.Set("entity", test.Entity) v.Set(fmt.Sprintf("projectmap.%s", test.Regex.String()), test.Project) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params.Project.MapPatterns) @@ -685,6 +709,8 @@ func TestLoadHeartbeatParams_ProjectMap(t *testing.T) { } func TestLoadAPIParams_ProjectApiKey(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Entity string Regex regex.Regex @@ -692,37 +718,37 @@ func TestLoadAPIParams_ProjectApiKey(t *testing.T) { Expected []apikey.MapPattern }{ "simple regex": { - Regex: regexp.MustCompile("projects/foo"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("projects/foo")), APIKey: "00000000-0000-4000-8000-000000000001", Expected: []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regexp.MustCompile(`(?i)projects/foo`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`(?i)projects/foo`)), }, }, }, "complex regex": { - Regex: regexp.MustCompile(`^/home/user/projects/bar(\\d+)/`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`^/home/user/projects/bar(\\d+)/`)), APIKey: "00000000-0000-4000-8000-000000000002", Expected: []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000002", - Regex: regexp.MustCompile(`(?i)^/home/user/projects/bar(\\d+)/`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`(?i)^/home/user/projects/bar(\\d+)/`)), }, }, }, "case insensitive": { - Regex: regexp.MustCompile("projects/foo"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("projects/foo")), APIKey: "00000000-0000-4000-8000-000000000001", Expected: []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regexp.MustCompile(`(?i)projects/foo`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`(?i)projects/foo`)), }, }, }, "api key equal to default": { - Regex: regexp.MustCompile(`/some/path`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`/some/path`)), APIKey: "00000000-0000-4000-8000-000000000000", Expected: nil, }, @@ -734,7 +760,7 @@ func TestLoadAPIParams_ProjectApiKey(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set(fmt.Sprintf("project_api_key.%s", test.Regex.String()), test.APIKey) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params.KeyPatterns) @@ -743,23 +769,25 @@ func TestLoadAPIParams_ProjectApiKey(t *testing.T) { } func TestLoadAPIParams_ProjectApiKey_ParseConfig(t *testing.T) { + ctx := context.Background() + v := viper.New() v.Set("config", "testdata/.wakatime.cfg") v.Set("entity", "testdata/heartbeat_go.json") - configFile, err := inipkg.FilePath(v) + configFile, err := inipkg.FilePath(ctx, v) require.NoError(t, err) err = inipkg.ReadInConfig(v, configFile) require.NoError(t, err) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(ctx, v) require.NoError(t, err) expected := []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regex.MustCompile("(?i)/some/path"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("(?i)/some/path")), }, } @@ -770,7 +798,7 @@ func TestLoadAPIParams_APIKeyPrefixSupported(t *testing.T) { v := viper.New() v.Set("key", "waka_00000000-0000-4000-8000-000000000000") - _, err := cmdparams.LoadAPIParams(v) + _, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) } @@ -779,7 +807,7 @@ func TestLoadHeartbeatParams_Time(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("time", 1590609206.1) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, 1590609206.1, params.Time) @@ -789,7 +817,7 @@ func TestLoadHeartbeatParams_Time_Default(t *testing.T) { v := viper.New() v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) now := float64(time.Now().UnixNano()) / 1000000000 @@ -804,7 +832,7 @@ func TestLoadHeartbeatParams_Filter_Exclude(t *testing.T) { v.Set("settings.exclude", []string{".+", "wakatime.+"}) v.Set("settings.ignore", []string{".?", "wakatime.?"}) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) require.Len(t, params.Filter.Exclude, 6) @@ -821,7 +849,7 @@ func TestLoadHeartbeatParams_Filter_Exclude_Multiline(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("settings.ignore", "\t.?\n\twakatime.? \t\n") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) require.Len(t, params.Filter.Exclude, 2) @@ -834,7 +862,7 @@ func TestLoadHeartbeatParams_Filter_Exclude_IgnoresInvalidRegex(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("exclude", []string{".*", "["}) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) require.Len(t, params.Filter.Exclude, 1) @@ -853,7 +881,7 @@ func TestLoadHeartbeatParams_Filter_Exclude_PerlRegexPatterns(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("exclude", []string{pattern}) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) require.Len(t, params.Filter.Exclude, 1) @@ -867,7 +895,7 @@ func TestLoadHeartbeatParams_Filter_ExcludeUnknownProject(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("exclude-unknown-project", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.Filter.ExcludeUnknownProject) @@ -879,7 +907,7 @@ func TestLoadHeartbeatParams_Filter_ExcludeUnknownProject_FromConfig(t *testing. v.Set("exclude-unknown-project", false) v.Set("settings.exclude_unknown_project", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.Filter.ExcludeUnknownProject) @@ -891,7 +919,7 @@ func TestLoadHeartbeatParams_Filter_Include(t *testing.T) { v.Set("include", []string{".*", "wakatime.*"}) v.Set("settings.include", []string{".+", "wakatime.+"}) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) require.Len(t, params.Filter.Include, 4) @@ -906,7 +934,7 @@ func TestLoadHeartbeatParams_Filter_Include_IgnoresInvalidRegex(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("include", []string{".*", "["}) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) require.Len(t, params.Filter.Include, 1) @@ -914,6 +942,8 @@ func TestLoadHeartbeatParams_Filter_Include_IgnoresInvalidRegex(t *testing.T) { } func TestLoadHeartbeatParams_Filter_Include_PerlRegexPatterns(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "negative lookahead": `^/var/(?!www/).*`, "positive lookahead": `^/var/(?=www/).*`, @@ -925,7 +955,7 @@ func TestLoadHeartbeatParams_Filter_Include_PerlRegexPatterns(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("include", []string{pattern}) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) require.Len(t, params.Filter.Include, 1) @@ -939,7 +969,7 @@ func TestLoadHeartbeatParams_Filter_IncludeOnlyWithProjectFile(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("include-only-with-project-file", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.Filter.IncludeOnlyWithProjectFile) @@ -951,13 +981,15 @@ func TestLoadHeartbeatParams_Filter_IncludeOnlyWithProjectFile_FromConfig(t *tes v.Set("include-only-with-project-file", false) v.Set("settings.include_only_with_project_file", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.Filter.IncludeOnlyWithProjectFile) } func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_True(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "true", "uppercase": "TRUE", @@ -970,17 +1002,19 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_True(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("hide-branch-names", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regex.MustCompile(".*")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) }) } } func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_False(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "false", "uppercase": "FALSE", @@ -993,17 +1027,19 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_False(t *testing.T) v.Set("entity", "/path/to/file") v.Set("hide-branch-names", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regex.MustCompile("a^")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("a^"))}, }, params.Sanitize) }) } } func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_List(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { ViperValue string Expected []regex.Regex @@ -1011,14 +1047,14 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_List(t *testing.T) { "regex": { ViperValue: "fix.*", Expected: []regex.Regex{ - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, "regex list": { ViperValue: ".*secret.*\nfix.*", Expected: []regex.Regex{ - regexp.MustCompile(".*secret.*"), - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile(".*secret.*")), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, } @@ -1029,7 +1065,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_List(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("hide-branch-names", test.ViperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ @@ -1047,11 +1083,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_FlagTakesPrecedence( v.Set("settings.hide_branchnames", "ignored") v.Set("settings.hidebranchnames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regexp.MustCompile(".*")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1062,11 +1098,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_ConfigTakesPrecedenc v.Set("settings.hide_branchnames", "ignored") v.Set("settings.hidebranchnames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regexp.MustCompile(".*")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1076,11 +1112,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_ConfigDeprecatedOneT v.Set("settings.hide_branchnames", "true") v.Set("settings.hidebranchnames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regexp.MustCompile(".*")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1089,11 +1125,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_ConfigDeprecatedTwo( v.Set("entity", "/path/to/file") v.Set("settings.hidebranchnames", "true") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regexp.MustCompile(".*")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1102,7 +1138,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_InvalidRegex(t *test v.Set("entity", "/path/to/file") v.Set("hide-branch-names", ".*secret.*\n[0-9+") - _, err := cmdparams.LoadHeartbeatParams(v) + _, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.Error(t, err) assert.True(t, strings.HasPrefix( @@ -1114,6 +1150,8 @@ func TestLoadHeartbeatParams_SanitizeParams_HideBranchNames_InvalidRegex(t *test } func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_True(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "true", "uppercase": "TRUE", @@ -1126,17 +1164,19 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_True(t *testing.T) v.Set("entity", "/path/to/file") v.Set("hide-project-names", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideProjectNames: []regex.Regex{regexp.MustCompile(".*")}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) }) } } func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_False(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "false", "uppercase": "FALSE", @@ -1149,17 +1189,19 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_False(t *testing.T) v.Set("entity", "/path/to/file") v.Set("hide-project-names", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideProjectNames: []regex.Regex{regexp.MustCompile("a^")}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("a^"))}, }, params.Sanitize) }) } } func TestLoadHeartbeatParams_SanitizeParams_HideProjecthNames_List(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { ViperValue string Expected []regex.Regex @@ -1167,14 +1209,14 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjecthNames_List(t *testing.T) "regex": { ViperValue: "fix.*", Expected: []regex.Regex{ - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, "regex list": { ViperValue: ".*secret.*\nfix.*", Expected: []regex.Regex{ - regexp.MustCompile(".*secret.*"), - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile(".*secret.*")), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, } @@ -1185,7 +1227,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjecthNames_List(t *testing.T) v.Set("entity", "/path/to/file") v.Set("hide-project-names", test.ViperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ @@ -1203,11 +1245,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_FlagTakesPrecedence v.Set("settings.hide_projectnames", "ignored") v.Set("settings.hideprojectnames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideProjectNames: []regex.Regex{regexp.MustCompile(".*")}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1218,11 +1260,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_ConfigTakesPreceden v.Set("settings.hide_projectnames", "ignored") v.Set("settings.hideprojectnames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideProjectNames: []regex.Regex{regexp.MustCompile(".*")}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1232,11 +1274,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_ConfigDeprecatedOne v.Set("settings.hide_projectnames", "true") v.Set("settings.hideprojectnames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideProjectNames: []regex.Regex{regexp.MustCompile(".*")}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1245,11 +1287,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_ConfigDeprecatedTwo v.Set("entity", "/path/to/file") v.Set("settings.hideprojectnames", "true") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideProjectNames: []regex.Regex{regexp.MustCompile(".*")}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1258,7 +1300,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_InvalidRegex(t *tes v.Set("entity", "/path/to/file") v.Set("hide-project-names", ".*secret.*\n[0-9+") - _, err := cmdparams.LoadHeartbeatParams(v) + _, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.Error(t, err) assert.True(t, strings.HasPrefix( @@ -1270,6 +1312,8 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectNames_InvalidRegex(t *tes } func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_True(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "true", "uppercase": "TRUE", @@ -1282,17 +1326,19 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_True(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("hide-file-names", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) }) } } func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_False(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "false", "uppercase": "FALSE", @@ -1305,17 +1351,19 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_False(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("hide-file-names", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile("a^")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("a^"))}, }, params.Sanitize) }) } } func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_List(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { ViperValue string Expected []regex.Regex @@ -1323,14 +1371,14 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_List(t *testing.T) { "regex": { ViperValue: "fix.*", Expected: []regex.Regex{ - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, "regex list": { ViperValue: ".*secret.*\nfix.*", Expected: []regex.Regex{ - regexp.MustCompile(".*secret.*"), - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile(".*secret.*")), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, } @@ -1341,7 +1389,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_List(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("hide-file-names", test.ViperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ @@ -1361,11 +1409,11 @@ func TestLoadheartbeatParams_SanitizeParams_HideFileNames_FlagTakesPrecedence(t v.Set("settings.hide_filenames", "ignored") v.Set("settings.hidefilenames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1378,11 +1426,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_FlagDeprecatedOneTakes v.Set("settings.hide_filenames", "ignored") v.Set("settings.hidefilenames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1394,11 +1442,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_FlagDeprecatedTwoTakes v.Set("settings.hide_filenames", "ignored") v.Set("settings.hidefilenames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1409,11 +1457,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_ConfigTakesPrecedence( v.Set("settings.hide_filenames", "ignored") v.Set("settings.hidefilenames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1423,11 +1471,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_ConfigDeprecatedOneTak v.Set("settings.hide_filenames", "true") v.Set("settings.hidefilenames", "ignored") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1436,11 +1484,11 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_ConfigDeprecatedTwo(t v.Set("entity", "/path/to/file") v.Set("settings.hidefilenames", "true") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ - HideFileNames: []regex.Regex{regexp.MustCompile(".*")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }, params.Sanitize) } @@ -1449,7 +1497,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideFileNames_InvalidRegex(t *testin v.Set("entity", "/path/to/file") v.Set("hide-file-names", ".*secret.*\n[0-9+") - _, err := cmdparams.LoadHeartbeatParams(v) + _, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.Error(t, err) assert.True(t, strings.HasPrefix( @@ -1465,7 +1513,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectFolder(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("hide-project-folder", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ @@ -1478,7 +1526,7 @@ func TestLoadHeartbeatParams_SanitizeParams_HideProjectFolder_ConfigTakesPrecede v.Set("entity", "/path/to/file") v.Set("settings.hide_project_folder", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ @@ -1491,7 +1539,7 @@ func TestLoadHeartbeatParams_SanitizeParams_OverrideProjectPath(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("project-folder", "/custom-path") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.SanitizeParams{ @@ -1512,15 +1560,17 @@ func TestLoadHeartbeatParams_SubmodulesDisabled_True(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("git.submodules_disabled", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) - assert.Equal(t, []regex.Regex{regexp.MustCompile(".*")}, params.Project.SubmodulesDisabled) + assert.Equal(t, []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, params.Project.SubmodulesDisabled) }) } } func TestLoadHeartbeatParams_SubmodulesDisabled_False(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "lowercase": "false", "uppercase": "FALSE", @@ -1533,15 +1583,17 @@ func TestLoadHeartbeatParams_SubmodulesDisabled_False(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("git.submodules_disabled", viperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) - assert.Equal(t, params.Project.SubmodulesDisabled, []regex.Regex{regexp.MustCompile("a^")}) + assert.Equal(t, params.Project.SubmodulesDisabled, []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("a^"))}) }) } } func TestLoadHeartbeatsParams_SubmodulesDisabled_List(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { ViperValue string Expected []regex.Regex @@ -1549,14 +1601,14 @@ func TestLoadHeartbeatsParams_SubmodulesDisabled_List(t *testing.T) { "regex": { ViperValue: "fix.*", Expected: []regex.Regex{ - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, "regex_list": { ViperValue: "\n.*secret.*\nfix.*", Expected: []regex.Regex{ - regexp.MustCompile(".*secret.*"), - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile(".*secret.*")), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, }, } @@ -1568,7 +1620,7 @@ func TestLoadHeartbeatsParams_SubmodulesDisabled_List(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("git.submodules_disabled", test.ViperValue) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params.Project.SubmodulesDisabled) @@ -1577,6 +1629,8 @@ func TestLoadHeartbeatsParams_SubmodulesDisabled_List(t *testing.T) { } func TestLoadHeartbeatsParams_SubmoduleProjectMap(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Entity string Regex regex.Regex @@ -1585,23 +1639,23 @@ func TestLoadHeartbeatsParams_SubmoduleProjectMap(t *testing.T) { }{ "simple regex": { Entity: "/home/user/projects/foo/file", - Regex: regexp.MustCompile("projects/foo"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("projects/foo")), Project: "My Awesome Project", Expected: []project.MapPattern{ { Name: "My Awesome Project", - Regex: regexp.MustCompile("(?i)projects/foo"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("(?i)projects/foo")), }, }, }, "regex with group replacement": { Entity: "/home/user/projects/bar123/file", - Regex: regexp.MustCompile(`^/home/user/projects/bar(\\d+)/`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`^/home/user/projects/bar(\\d+)/`)), Project: "project{0}", Expected: []project.MapPattern{ { Name: "project{0}", - Regex: regexp.MustCompile(`(?i)^/home/user/projects/bar(\\d+)/`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`(?i)^/home/user/projects/bar(\\d+)/`)), }, }, }, @@ -1613,7 +1667,7 @@ func TestLoadHeartbeatsParams_SubmoduleProjectMap(t *testing.T) { v.Set("entity", test.Entity) v.Set(fmt.Sprintf("git_submodule_projectmap.%s", test.Regex.String()), test.Project) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params.Project.SubmoduleMapPatterns) @@ -1626,7 +1680,7 @@ func TestLoadAPIParams_Plugin(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("plugin", "plugin/10.0.0") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "plugin/10.0.0", params.Plugin) @@ -1636,7 +1690,7 @@ func TestLoadAPIParams_Plugin_Unset(t *testing.T) { v := viper.New() v.Set("key", "00000000-0000-4000-8000-000000000000") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Empty(t, params.Plugin) @@ -1648,7 +1702,7 @@ func TestLoadAPIParams_Timeout_FlagTakesPrecedence(t *testing.T) { v.Set("timeout", 5) v.Set("settings.timeout", 10) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, 5*time.Second, params.Timeout) @@ -1659,7 +1713,7 @@ func TestLoadAPIParams_Timeout_FromConfig(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("settings.timeout", 10) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, 10*time.Second, params.Timeout) @@ -1671,7 +1725,7 @@ func TestLoadOfflineParams_Disabled_ConfigTakesPrecedence(t *testing.T) { v.Set("disableoffline", false) v.Set("settings.offline", false) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.True(t, params.Disabled) } @@ -1681,7 +1735,7 @@ func TestLoadOfflineParams_Disabled_FlagDeprecatedTakesPrecedence(t *testing.T) v.Set("disable-offline", false) v.Set("disableoffline", true) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.True(t, params.Disabled) } @@ -1690,7 +1744,7 @@ func TestLoadOfflineParams_Disabled_FromFlag(t *testing.T) { v := viper.New() v.Set("disable-offline", true) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.True(t, params.Disabled) } @@ -1700,7 +1754,7 @@ func TestLoadOfflineParams_RateLimit_FlagTakesPrecedence(t *testing.T) { v.Set("heartbeat-rate-limit-seconds", 5) v.Set("settings.heartbeat_rate_limit_seconds", 10) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Equal(t, time.Duration(5)*time.Second, params.RateLimit) } @@ -1709,7 +1763,7 @@ func TestLoadOfflineParams_RateLimit_FromConfig(t *testing.T) { v := viper.New() v.Set("settings.heartbeat_rate_limit_seconds", 10) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Equal(t, time.Duration(10)*time.Second, params.RateLimit) } @@ -1718,7 +1772,7 @@ func TestLoadOfflineParams_RateLimit_Zero(t *testing.T) { v := viper.New() v.Set("heartbeat-rate-limit-seconds", "0") - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.RateLimit) } @@ -1727,7 +1781,7 @@ func TestLoadOfflineParams_RateLimit_Default(t *testing.T) { v := viper.New() v.SetDefault("heartbeat-rate-limit-seconds", 20) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Equal(t, time.Duration(20)*time.Second, params.RateLimit) } @@ -1736,7 +1790,7 @@ func TestLoadOfflineParams_RateLimit_NegativeNumber(t *testing.T) { v := viper.New() v.Set("heartbeat-rate-limit-seconds", -1) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.RateLimit) } @@ -1745,7 +1799,7 @@ func TestLoadOfflineParams_RateLimit_NonIntegerValue(t *testing.T) { v := viper.New() v.Set("heartbeat-rate-limit-seconds", "invalid") - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.RateLimit) } @@ -1754,7 +1808,7 @@ func TestLoadOfflineParams_LastSentAt(t *testing.T) { v := viper.New() v.Set("internal.heartbeats_last_sent_at", "2021-08-30T18:50:42-03:00") - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) lastSentAt, err := time.Parse(inipkg.DateFormat, "2021-08-30T18:50:42-03:00") require.NoError(t, err) @@ -1766,7 +1820,7 @@ func TestLoadOfflineParams_LastSentAt_Err(t *testing.T) { v := viper.New() v.Set("internal.heartbeats_last_sent_at", "2021-08-30") - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.LastSentAt) } @@ -1776,7 +1830,7 @@ func TestLoadOfflineParams_LastSentAtFuture(t *testing.T) { lastSentAt := time.Now().Add(time.Duration(2) * time.Hour) v.Set("internal.heartbeats_last_sent_at", lastSentAt.Format(inipkg.DateFormat)) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.LessOrEqual(t, params.LastSentAt, time.Now()) } @@ -1785,7 +1839,7 @@ func TestLoadOfflineParams_SyncMax(t *testing.T) { v := viper.New() v.Set("sync-offline-activity", 42) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Equal(t, 42, params.SyncMax) } @@ -1794,7 +1848,7 @@ func TestLoadOfflineParams_SyncMax_Zero(t *testing.T) { v := viper.New() v.Set("sync-offline-activity", "0") - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.SyncMax) } @@ -1803,7 +1857,7 @@ func TestLoadOfflineParams_SyncMax_Default(t *testing.T) { v := viper.New() v.SetDefault("sync-offline-activity", 1000) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Equal(t, 1000, params.SyncMax) } @@ -1812,7 +1866,7 @@ func TestLoadOfflineParams_SyncMax_NegativeNumber(t *testing.T) { v := viper.New() v.Set("sync-offline-activity", -1) - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.SyncMax) } @@ -1821,12 +1875,14 @@ func TestLoadOfflineParams_SyncMax_NonIntegerValue(t *testing.T) { v := viper.New() v.Set("sync-offline-activity", "invalid") - params := cmdparams.LoadOfflineParams(v) + params := cmdparams.LoadOfflineParams(context.Background(), v) assert.Zero(t, params.SyncMax) } func TestLoadAPIParams_APIKey(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { ViperAPIKey string ViperAPIKeyConfig string @@ -1870,7 +1926,7 @@ func TestLoadAPIParams_APIKey(t *testing.T) { v.Set("settings.api_key", test.ViperAPIKeyConfig) v.Set("settings.apikey", test.ViperAPIKeyConfigOld) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params) @@ -1882,17 +1938,18 @@ func TestLoadAPIParams_APIKeyUnset(t *testing.T) { v := viper.New() v.Set("key", "") - _, err := cmdparams.LoadAPIParams(v) + _, err := cmdparams.LoadAPIParams(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth assert.ErrorAs(t, err, &errauth) - assert.EqualError(t, errauth, "api key not found or empty") } func TestLoadAPIParams_APIKeyInvalid(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "invalid format 1": "not-uuid", "invalid format 2": "00000000-0000-0000-8000-000000000000", @@ -1904,7 +1961,7 @@ func TestLoadAPIParams_APIKeyInvalid(t *testing.T) { v := viper.New() v.Set("key", value) - _, err := cmdparams.LoadAPIParams(v) + _, err := cmdparams.LoadAPIParams(ctx, v) require.Error(t, err) var errauth api.ErrAuth @@ -1920,13 +1977,13 @@ func TestLoadAPIParams_ApiKey_SettingTakePrecedence(t *testing.T) { v.Set("config", "testdata/.wakatime.cfg") v.Set("entity", "testdata/heartbeat_go.json") - configFile, err := inipkg.FilePath(v) + configFile, err := inipkg.FilePath(context.Background(), v) require.NoError(t, err) err = inipkg.ReadInConfig(v, configFile) require.NoError(t, err) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "00000000-0000-4000-8000-000000000000", params.Key) @@ -1937,13 +1994,13 @@ func TestLoadAPIParams_ApiKey_FromVault(t *testing.T) { v.Set("config", "testdata/.wakatime-vault.cfg") v.Set("entity", "testdata/heartbeat_go.json") - configFile, err := inipkg.FilePath(v) + configFile, err := inipkg.FilePath(context.Background(), v) require.NoError(t, err) err = inipkg.ReadInConfig(v, configFile) require.NoError(t, err) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "00000000-0000-4000-8000-000000000000", params.Key) @@ -1954,17 +2011,19 @@ func TestLoadParams_ApiKey_FromVault_Err_Darwin(t *testing.T) { t.Skip("Skipping because OS is not darwin.") } + ctx := context.Background() + v := viper.New() v.Set("config", "testdata/.wakatime-vault-error.cfg") v.Set("entity", "testdata/heartbeat_go.json") - configFile, err := inipkg.FilePath(v) + configFile, err := inipkg.FilePath(ctx, v) require.NoError(t, err) err = inipkg.ReadInConfig(v, configFile) require.NoError(t, err) - _, err = cmdparams.LoadAPIParams(v) + _, err = cmdparams.LoadAPIParams(ctx, v) assert.EqualError(t, err, "failed to read api key from vault: exit status 1") } @@ -1977,7 +2036,7 @@ func TestLoadAPIParams_APIKeyFromEnv(t *testing.T) { defer os.Unsetenv("WAKATIME_API_KEY") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "00000000-0000-4000-8000-000000000000", params.Key) @@ -1991,13 +2050,12 @@ func TestLoadAPIParams_APIKeyFromEnvInvalid(t *testing.T) { defer os.Unsetenv("WAKATIME_API_KEY") - _, err = cmdparams.LoadAPIParams(v) + _, err = cmdparams.LoadAPIParams(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth assert.ErrorAs(t, err, &errauth) - assert.EqualError(t, errauth, "invalid api key format") } @@ -2010,13 +2068,15 @@ func TestLoadAPIParams_APIKeyFromEnv_ConfigTakesPrecedence(t *testing.T) { defer os.Unsetenv("WAKATIME_API_KEY") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "00000000-0000-4000-8000-000000000000", params.Key) } func TestLoadAPIParams_APIUrl(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { ViperAPIUrl string ViperAPIUrlConfig string @@ -2085,7 +2145,7 @@ func TestLoadAPIParams_APIUrl(t *testing.T) { v.Set("apiurl", test.ViperAPIUrlOld) v.Set("settings.api_url", test.ViperAPIUrlConfig) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, params) @@ -2097,7 +2157,7 @@ func TestLoadAPIParams_Url_Default(t *testing.T) { v := viper.New() v.Set("key", "00000000-0000-4000-8000-000000000000") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, api.BaseURL, params.URL) @@ -2108,12 +2168,11 @@ func TestLoadAPIParams_Url_InvalidFormat(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", "http://in valid") - _, err := cmdparams.LoadAPIParams(v) + _, err := cmdparams.LoadAPIParams(context.Background(), v) var errauth api.ErrAuth assert.ErrorAs(t, err, &errauth) - assert.EqualError(t, errauth, `invalid api url: parse "http://in valid": invalid character " " in host name`) } @@ -2124,7 +2183,7 @@ func TestLoadAPIParams_BackoffAt(t *testing.T) { v.Set("internal.backoff_at", "2021-08-30T18:50:42-03:00") v.Set("internal.backoff_retries", "3") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) backoffAt, err := time.Parse(inipkg.DateFormat, "2021-08-30T18:50:42-03:00") @@ -2146,7 +2205,7 @@ func TestLoadAPIParams_BackoffAtErr(t *testing.T) { v.Set("internal.backoff_at", "2021-08-30") v.Set("internal.backoff_retries", "2") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, cmdparams.API{ @@ -2167,7 +2226,7 @@ func TestLoadAPIParams_BackoffAtFuture(t *testing.T) { v.Set("internal.backoff_at", backoff.Format(inipkg.DateFormat)) v.Set("internal.backoff_retries", "3") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, 3, params.BackoffRetries) @@ -2180,7 +2239,7 @@ func TestLoadAPIParams_DisableSSLVerify_FlagTakesPrecedence(t *testing.T) { v.Set("no-ssl-verify", true) v.Set("settings.no_ssl_verify", false) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.DisableSSLVerify) @@ -2191,7 +2250,7 @@ func TestLoadAPIParams_DisableSSLVerify_FromConfig(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("settings.no_ssl_verify", true) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.DisableSSLVerify) @@ -2201,13 +2260,15 @@ func TestLoadAPIParams_DisableSSLVerify_Default(t *testing.T) { v := viper.New() v.Set("key", "00000000-0000-4000-8000-000000000000") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.False(t, params.DisableSSLVerify) } func TestLoadAPIParams_ProxyURL(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "https": "https://john:secret@example.org:8888", "http": "http://john:secret@example.org:8888", @@ -2222,7 +2283,7 @@ func TestLoadAPIParams_ProxyURL(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("proxy", proxyURL) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(ctx, v) require.NoError(t, err) assert.Equal(t, proxyURL, params.ProxyURL) @@ -2236,7 +2297,7 @@ func TestLoadAPIParams_ProxyURL_FlagTakesPrecedence(t *testing.T) { v.Set("proxy", "https://john:secret@example.org:8888") v.Set("settings.proxy", "ignored") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "https://john:secret@example.org:8888", params.ProxyURL) @@ -2252,7 +2313,7 @@ func TestLoadAPIParams_ProxyURL_UserDefinedTakesPrecedenceOverEnvironment(t *tes defer os.Unsetenv("HTTPS_PROXY") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "https://john:secret@example.org:8888", params.ProxyURL) @@ -2263,7 +2324,7 @@ func TestLoadAPIParams_ProxyURL_FromConfig(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("settings.proxy", "https://john:secret@example.org:8888") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "https://john:secret@example.org:8888", params.ProxyURL) @@ -2278,7 +2339,7 @@ func TestLoadAPIParams_ProxyURL_FromEnvironment(t *testing.T) { defer os.Unsetenv("HTTPS_PROXY") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "https://john:secret@example.org:8888", params.ProxyURL) @@ -2293,7 +2354,7 @@ func TestLoadAPIParams_ProxyURL_NoProxyFromEnvironment(t *testing.T) { defer os.Unsetenv("NO_PROXY") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Empty(t, params.ProxyURL) @@ -2304,12 +2365,11 @@ func TestLoadAPIParams_ProxyURL_InvalidFormat(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("proxy", "ftp://john:secret@example.org:8888") - _, err := cmdparams.LoadAPIParams(v) + _, err := cmdparams.LoadAPIParams(context.Background(), v) var errauth api.ErrAuth assert.ErrorAs(t, err, &errauth) - assert.EqualError( t, err, @@ -2325,7 +2385,7 @@ func TestLoadAPIParams_SSLCertFilepath_FlagTakesPrecedence(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, filepath.Join(home, "/path/to/cert.pem"), params.SSLCertFilepath) @@ -2336,7 +2396,7 @@ func TestLoadAPIParams_SSLCertFilepath_FromConfig(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("settings.ssl_certs_file", "/path/to/cert.pem") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "/path/to/cert.pem", params.SSLCertFilepath) @@ -2353,7 +2413,7 @@ func TestLoadAPIParams_Hostname_FlagTakesPrecedence(t *testing.T) { defer os.Unsetenv("GITPOD_WORKSPACE_ID") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "my-machine", params.Hostname) @@ -2364,7 +2424,7 @@ func TestLoadAPIParams_Hostname_FromConfig(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("settings.hostname", "my-machine") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "my-machine", params.Hostname) @@ -2380,7 +2440,7 @@ func TestLoadAPIParams_Hostname_ConfigTakesPrecedence(t *testing.T) { defer os.Unsetenv("GITPOD_WORKSPACE_ID") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "my-machine", params.Hostname) @@ -2395,7 +2455,7 @@ func TestLoadAPIParams_Hostname_FromGitpodEnv(t *testing.T) { defer os.Unsetenv("GITPOD_WORKSPACE_ID") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "Gitpod", params.Hostname) @@ -2405,7 +2465,7 @@ func TestLoadAPIParams_Hostname_DefaultFromSystem(t *testing.T) { v := viper.New() v.Set("key", "00000000-0000-4000-8000-000000000000") - params, err := cmdparams.LoadAPIParams(v) + params, err := cmdparams.LoadAPIParams(context.Background(), v) require.NoError(t, err) expected, err := os.Hostname() @@ -2486,7 +2546,7 @@ func TestAPI_String(t *testing.T) { KeyPatterns: []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regex.MustCompile("^/api/v1/"), + Regex: regex.NewRegexpWrap(regexp.MustCompile("^/api/v1/")), }, }, Plugin: "my-plugin", @@ -2508,9 +2568,9 @@ func TestAPI_String(t *testing.T) { func TestFilterParams_String(t *testing.T) { filterparams := cmdparams.FilterParams{ - Exclude: []regex.Regex{regex.MustCompile("^/exclude")}, + Exclude: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("^/exclude"))}, ExcludeUnknownProject: true, - Include: []regex.Regex{regex.MustCompile("^/include")}, + Include: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("^/include"))}, IncludeOnlyWithProjectFile: true, } @@ -2577,12 +2637,18 @@ func TestOffline_String(t *testing.T) { func TestProjectParams_String(t *testing.T) { projectparams := cmdparams.ProjectParams{ - Alternate: "alternate", - BranchAlternate: "branch-alternate", - MapPatterns: []project.MapPattern{{Name: "project-1", Regex: regex.MustCompile("^/regex")}}, - Override: "override", - SubmodulesDisabled: []regex.Regex{regexp.MustCompile(".*")}, - SubmoduleMapPatterns: []project.MapPattern{{Name: "awesome-project", Regex: regex.MustCompile("^/regex")}}, + Alternate: "alternate", + BranchAlternate: "branch-alternate", + MapPatterns: []project.MapPattern{{ + Name: "project-1", + Regex: regex.NewRegexpWrap(regexp.MustCompile("^/regex")), + }}, + Override: "override", + SubmodulesDisabled: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, + SubmoduleMapPatterns: []project.MapPattern{{ + Name: "awesome-project", + Regex: regex.NewRegexpWrap(regexp.MustCompile("^/regex")), + }}, } assert.Equal( @@ -2599,7 +2665,7 @@ func TestLoadHeartbeatParams_ProjectFromGitRemote(t *testing.T) { v.Set("git.project_from_git_remote", true) v.Set("entity", "/path/to/file") - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.True(t, params.Project.ProjectFromGitRemote) @@ -2607,10 +2673,10 @@ func TestLoadHeartbeatParams_ProjectFromGitRemote(t *testing.T) { func TestSanitizeParams_String(t *testing.T) { sanitizeparams := cmdparams.SanitizeParams{ - HideBranchNames: []regex.Regex{regex.MustCompile("^/hide")}, + HideBranchNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("^/hide"))}, HideProjectFolder: true, - HideFileNames: []regex.Regex{regex.MustCompile("^/hide")}, - HideProjectNames: []regex.Regex{regex.MustCompile("^/hide")}, + HideFileNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("^/hide"))}, + HideProjectNames: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("^/hide"))}, ProjectPathOverride: "path/to/project", } @@ -2664,7 +2730,7 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_StdinReadOnlyOnce(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("extra-heartbeats", true) - params, err := cmdparams.LoadHeartbeatParams(v) + params, err := cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Len(t, params.ExtraHeartbeats, 2) @@ -2691,7 +2757,7 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_StdinReadOnlyOnce(t *testing.T) { v.Set("entity", "/path/to/file") v.Set("extra-heartbeats", true) - params, err = cmdparams.LoadHeartbeatParams(v) + params, err = cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Len(t, params.ExtraHeartbeats, 2) @@ -2703,25 +2769,9 @@ func TestLoadHeartbeatParams_ExtraHeartbeats_StdinReadOnlyOnce(t *testing.T) { cmdparams.Once = sync.Once{} - params, err = cmdparams.LoadHeartbeatParams(v) + params, err = cmdparams.LoadHeartbeatParams(context.Background(), v) require.NoError(t, err) assert.Len(t, params.ExtraHeartbeats, 2) assert.Empty(t, params.ExtraHeartbeats[0].LanguageAlternate) } - -func captureLogs(dest io.Writer) func() { - // set verbose - log.SetVerbose(true) - - logOutput := log.Output() - - // will write to log output and dest - mw := io.MultiWriter(logOutput, dest) - - log.SetOutput(mw) - - return func() { - log.SetOutput(logOutput) - } -} diff --git a/cmd/run.go b/cmd/run.go index aaad374b..15fde8c2 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -2,8 +2,10 @@ package cmd import ( "bytes" + "context" "fmt" "io" + stdlog "log" "os" "path/filepath" "runtime/debug" @@ -47,113 +49,119 @@ type diagnostics struct { // RunE executes commands parsed from a command line. func RunE(cmd *cobra.Command, v *viper.Viper) error { + ctx := context.Background() + // force setup logging otherwise log goes to std out - _, err := SetupLogging(v) + logger, err := SetupLogging(ctx, v) if err != nil { - log.Fatalf("failed to setup logging: %s", err) + // log to std out as logger is not setup yet + stdlog.Fatalf("failed to setup logging: %s", err) } - err = parseConfigFiles(v) + // save logger to context + ctx = log.ToContext(ctx, logger) + + err = parseConfigFiles(ctx, v) if err != nil { - log.Errorf("failed to parse config files: %s", err) + logger.Errorf("failed to parse config files: %s", err) if v.IsSet("entity") { - _ = saveHeartbeats(v) + _ = saveHeartbeats(ctx, v) return exitcode.Err{Code: exitcode.ErrConfigFileParse} } } - // setup logging again to use config file settings - logFileParams, err := SetupLogging(v) + // setup logging again to use config file settings if available + logger, err = SetupLogging(ctx, v) if err != nil { - log.Fatalf("failed to setup logging: %s", err) + logger.Fatalf("failed to setup logging: %s", err) } // register all custom lexers if err := lexer.RegisterAll(); err != nil { - log.Fatalf("failed to register custom lexers: %s", err) + logger.Fatalf("failed to register custom lexers: %s", err) } // start profiling if enabled - if logFileParams.Metrics { - shutdown, err := metrics.StartProfiling() + if logger.IsMetricsEnabled() { + shutdown, err := metrics.StartProfiling(ctx) if err != nil { - log.Errorf("failed to start profiling: %s", err) + logger.Errorf("failed to start profiling: %s", err) + } else { + defer shutdown() } - - defer shutdown() } if v.GetBool("user-agent") { - log.Debugln("command: user-agent") + logger.Debugln("command: user-agent") - fmt.Println(heartbeat.UserAgent(vipertools.GetString(v, "plugin"))) + fmt.Println(heartbeat.UserAgent(ctx, vipertools.GetString(v, "plugin"))) return nil } if v.GetBool("version") { - log.Debugln("command: version") + logger.Debugln("command: version") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, runVersion) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), runVersion) } if v.IsSet("config-read") { - log.Debugln("command: config-read") + logger.Debugln("command: config-read") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, configread.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), configread.Run) } if v.IsSet("config-write") { - log.Debugln("command: config-write") + logger.Debugln("command: config-write") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, configwrite.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), configwrite.Run) } if v.GetBool("today") { - log.Debugln("command: today") + logger.Debugln("command: today") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, today.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), today.Run) } if v.IsSet("today-goal") { - log.Debugln("command: today-goal") + logger.Debugln("command: today-goal") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, todaygoal.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), todaygoal.Run) } if v.GetBool("file-experts") { - log.Debugln("command: file-experts") + logger.Debugln("command: file-experts") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, fileexperts.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), fileexperts.Run) } if v.IsSet("entity") { - log.Debugln("command: heartbeat") + logger.Debugln("command: heartbeat") - return RunCmdWithOfflineSync(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, cmdheartbeat.Run) + return RunCmdWithOfflineSync(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), cmdheartbeat.Run) } if v.IsSet("sync-offline-activity") { - log.Debugln("command: sync-offline-activity") + logger.Debugln("command: sync-offline-activity") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, offlinesync.RunWithoutRateLimiting) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), offlinesync.RunWithoutRateLimiting) } if v.GetBool("offline-count") { - log.Debugln("command: offline-count") + logger.Debugln("command: offline-count") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, offlinecount.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), offlinecount.Run) } if v.IsSet("print-offline-heartbeats") { - log.Debugln("command: print-offline-heartbeats") + logger.Debugln("command: print-offline-heartbeats") - return RunCmd(v, logFileParams.Verbose, logFileParams.SendDiagsOnErrors, offlineprint.Run) + return RunCmd(ctx, v, logger.IsVerboseEnabled(), logger.SendDiagsOnErrors(), offlineprint.Run) } - log.Warnf("one of the following parameters has to be provided: %s", strings.Join([]string{ + logger.Warnf("one of the following parameters has to be provided: %s", strings.Join([]string{ "--config-read", "--config-write", "--entity", @@ -172,9 +180,9 @@ func RunE(cmd *cobra.Command, v *viper.Viper) error { return exitcode.Err{Code: exitcode.ErrGeneric} } -func parseConfigFiles(v *viper.Viper) error { +func parseConfigFiles(ctx context.Context, v *viper.Viper) error { var configFiles = []struct { - fn func(v *viper.Viper) (string, error) + fn func(context.Context, *viper.Viper) (string, error) vp *viper.Viper merge bool }{ @@ -193,8 +201,10 @@ func parseConfigFiles(v *viper.Viper) error { }, } + logger := log.Extract(ctx) + for _, c := range configFiles { - configFile, err := c.fn(v) + configFile, err := c.fn(ctx, v) if err != nil { return fmt.Errorf("error getting config file path: %s", err) } @@ -205,7 +215,7 @@ func parseConfigFiles(v *viper.Viper) error { // check if file exists if _, err := os.Stat(configFile); os.IsNotExist(err) { - log.Debugf("config file %q not present or not accessible", configFile) + logger.Debugf("config file %q not present or not accessible", configFile) continue } @@ -216,7 +226,7 @@ func parseConfigFiles(v *viper.Viper) error { if c.merge { err = v.MergeConfigMap(c.vp.AllSettings()) if err != nil { - log.Warnf("failed to merge configuration file: %s", err) + logger.Warnf("failed to merge configuration file: %s", err) } } } @@ -225,16 +235,19 @@ func parseConfigFiles(v *viper.Viper) error { } // SetupLogging uses the --log-file param to configure logging to file or stdout. -func SetupLogging(v *viper.Viper) (*logfile.Params, error) { - logfileParams, err := logfile.LoadParams(v) +// It returns a logger with the configured settings or the default settings if it's not set. +func SetupLogging(ctx context.Context, v *viper.Viper) (*log.Logger, error) { + params, err := logfile.LoadParams(ctx, v) if err != nil { return nil, fmt.Errorf("failed to load log params: %s", err) } + logger := log.New(params.Verbose, params.SendDiagsOnErrors, params.Metrics) + logFile := os.Stdout - if !logfileParams.ToStdout { - dir := filepath.Dir(logfileParams.File) + if !params.ToStdout { + dir := filepath.Dir(params.File) if _, err := os.Stat(dir); os.IsNotExist(err) { err := os.MkdirAll(dir, 0750) if err != nil { @@ -242,52 +255,54 @@ func SetupLogging(v *viper.Viper) (*logfile.Params, error) { } } - logFile, err = os.OpenFile(logfileParams.File, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) // nolint:gosec + logFile, err = os.OpenFile(params.File, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) // nolint:gosec if err != nil { return nil, fmt.Errorf("error opening log file: %s", err) } - log.SetOutput(logFile) + logger.SetOutput(logFile) } - log.SetVerbose(logfileParams.Verbose) - log.SetJww(logfileParams.Verbose, logFile) + logger.SetVerbose(params.Verbose) + log.SetJww(params.Verbose, logFile) - return &logfileParams, nil + return logger, nil } // cmdFn represents a command function. -type cmdFn func(v *viper.Viper) (int, error) +type cmdFn func(ctx context.Context, v *viper.Viper) (int, error) // RunCmd runs a command function and exits with the exit code returned by // the command function. Will send diagnostic on any errors or panics. -func RunCmd(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) error { - return runCmd(v, verbose, sendDiagsOnErrors, cmd) +func RunCmd(ctx context.Context, v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) error { + return runCmd(ctx, v, verbose, sendDiagsOnErrors, cmd) } // RunCmdWithOfflineSync runs a command function and exits with the exit code // returned by the command function. If command run was successful, it will execute // offline sync command afterwards. Will send diagnostic on any errors or panics. -func RunCmdWithOfflineSync(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) error { - if err := runCmd(v, verbose, sendDiagsOnErrors, cmd); err != nil { +func RunCmdWithOfflineSync(ctx context.Context, v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) error { + if err := runCmd(ctx, v, verbose, sendDiagsOnErrors, cmd); err != nil { return err } - return runCmd(v, verbose, sendDiagsOnErrors, offlinesync.RunWithRateLimiting) + return runCmd(ctx, v, verbose, sendDiagsOnErrors, offlinesync.RunWithRateLimiting) } // runCmd contains the main logic of RunCmd. // It will send diagnostic on any errors or panics. // On panic, it will send diagnostic and exit with ErrGeneric exit code. // On error, it will only send diagnostic if sendDiagsOnErrors and verbose is true. -func runCmd(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) (errresponse error) { +func runCmd(ctx context.Context, v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) (errresponse error) { logs := bytes.NewBuffer(nil) - resetLogs := captureLogs(logs) + resetLogs := captureLogs(ctx, logs) + + logger := log.Extract(ctx) // catch panics defer func() { if err := recover(); err != nil { - log.Errorf("panicked: %v. Stack: %s", err, string(debug.Stack())) + logger.Errorf("panicked: %v. Stack: %s", err, string(debug.Stack())) resetLogs() @@ -301,8 +316,8 @@ func runCmd(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) (er diags.Logs = logs.String() } - if err := sendDiagnostics(v, diags); err != nil { - log.Warnf("failed to send diagnostics: %s", err) + if err := sendDiagnostics(ctx, v, diags); err != nil { + logger.Warnf("failed to send diagnostics: %s", err) } errresponse = exitcode.Err{Code: exitcode.ErrGeneric} @@ -312,7 +327,7 @@ func runCmd(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) (er var err error // run command - exitCode, err := cmd(v) + exitCode, err := cmd(ctx, v) // nolint:nestif if err != nil { if errwaka, ok := err.(wakaerror.Error); ok { @@ -322,25 +337,25 @@ func runCmd(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) (er } if verbose { - log.Errorf("failed to run command: %s", err) + logger.Errorf("failed to run command: %s", err) } resetLogs() if verbose && sendDiagsOnErrors { - if err := sendDiagnostics(v, + if err := sendDiagnostics(ctx, v, diagnostics{ Logs: logs.String(), OriginalError: err.Error(), Stack: string(debug.Stack()), }); err != nil { - log.Warnf("failed to send diagnostics: %s", err) + logger.Warnf("failed to send diagnostics: %s", err) } } } if exitCode != exitcode.Success { - log.Debugf("command failed with exit code %d", exitCode) + logger.Debugf("command failed with exit code %d", exitCode) errresponse = exitcode.Err{Code: exitCode} } @@ -348,14 +363,16 @@ func runCmd(v *viper.Viper, verbose bool, sendDiagsOnErrors bool, cmd cmdFn) (er return errresponse } -func saveHeartbeats(v *viper.Viper) int { - queueFilepath, err := offline.QueueFilepath(v) +func saveHeartbeats(ctx context.Context, v *viper.Viper) int { + logger := log.Extract(ctx) + + queueFilepath, err := offline.QueueFilepath(ctx, v) if err != nil { - log.Warnf("failed to load offline queue filepath: %s", err) + logger.Warnf("failed to load offline queue filepath: %s", err) } - if err := cmdoffline.SaveHeartbeats(v, nil, queueFilepath); err != nil { - log.Errorf("failed to save heartbeats to offline queue: %s", err) + if err := cmdoffline.SaveHeartbeats(ctx, v, nil, queueFilepath); err != nil { + logger.Errorf("failed to save heartbeats to offline queue: %s", err) return exitcode.ErrGeneric } @@ -363,13 +380,13 @@ func saveHeartbeats(v *viper.Viper) int { return exitcode.Success } -func sendDiagnostics(v *viper.Viper, d diagnostics) error { - paramAPI, err := params.LoadAPIParams(v) +func sendDiagnostics(ctx context.Context, v *viper.Viper, d diagnostics) error { + paramAPI, err := params.LoadAPIParams(ctx, v) if err != nil { return fmt.Errorf("failed to load API parameters: %s", err) } - c, err := cmdapi.NewClient(paramAPI) + c, err := cmdapi.NewClient(ctx, paramAPI) if err != nil { return fmt.Errorf("failed to initialize api client: %s", err) } @@ -380,25 +397,28 @@ func sendDiagnostics(v *viper.Viper, d diagnostics) error { diagnostic.Stack(d.Stack), } - err = c.SendDiagnostics(paramAPI.Plugin, d.Panicked, diagnostics...) + err = c.SendDiagnostics(ctx, paramAPI.Plugin, d.Panicked, diagnostics...) if err != nil { return fmt.Errorf("failed to send diagnostics to the API: %s", err) } - log.Debugln("successfully sent diagnostics") + logger := log.Extract(ctx) + logger.Debugln("successfully sent diagnostics") return nil } -func captureLogs(dest io.Writer) func() { - logOutput := log.Output() +func captureLogs(ctx context.Context, dest io.Writer) func() { + logger := log.Extract(ctx) + + logOutput := logger.Output() // will write to log output and dest mw := io.MultiWriter(logOutput, dest) - log.SetOutput(mw) + logger.SetOutput(mw) return func() { - log.SetOutput(logOutput) + logger.SetOutput(logOutput) } } diff --git a/cmd/run_internal_test.go b/cmd/run_internal_test.go index 425599cd..7ead4ebf 100644 --- a/cmd/run_internal_test.go +++ b/cmd/run_internal_test.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "encoding/json" "errors" "fmt" @@ -8,13 +9,13 @@ import ( "net/http" "net/http/httptest" "os" - "runtime" "testing" "time" cmdheartbeat "github.com/wakatime/wakatime-cli/cmd/heartbeat" "github.com/wakatime/wakatime-cli/pkg/exitcode" "github.com/wakatime/wakatime-cli/pkg/ini" + "github.com/wakatime/wakatime-cli/pkg/log" "github.com/wakatime/wakatime-cli/pkg/version" "github.com/spf13/viper" @@ -25,7 +26,7 @@ import ( func TestRunCmd(t *testing.T) { v := viper.New() - err := runCmd(v, false, false, func(_ *viper.Viper) (int, error) { + err := runCmd(context.Background(), v, false, false, func(_ context.Context, _ *viper.Viper) (int, error) { return exitcode.Success, nil }) @@ -35,14 +36,13 @@ func TestRunCmd(t *testing.T) { func TestRunCmd_Err(t *testing.T) { v := viper.New() - err := runCmd(v, false, false, func(_ *viper.Viper) (int, error) { + err := runCmd(context.Background(), v, false, false, func(_ context.Context, _ *viper.Viper) (int, error) { return exitcode.ErrGeneric, errors.New("fail") }) var errexitcode exitcode.Err require.ErrorAs(t, err, &errexitcode) - assert.Equal(t, exitcode.ErrGeneric, err.(exitcode.Err).Code) } @@ -99,7 +99,7 @@ func TestRunCmd_ErrOfflineEnqueue(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("plugin", "vim") - err := runCmd(v, true, false, func(_ *viper.Viper) (int, error) { + err := runCmd(context.Background(), v, true, false, func(_ context.Context, _ *viper.Viper) (int, error) { return exitcode.ErrGeneric, errors.New("fail") }) @@ -111,10 +111,7 @@ func TestRunCmd_ErrOfflineEnqueue(t *testing.T) { } func TestRunCmd_BackoffLoggedWithVerbose(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is windows.") - } - + ctx := context.Background() verbose := true testServerURL, router, tearDown := setupTestServer() @@ -155,16 +152,19 @@ func TestRunCmd_BackoffLoggedWithVerbose(t *testing.T) { v.Set("internal.backoff_retries", "1") v.Set("verbose", verbose) - _, _ = SetupLogging(v) + logger, err := SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) - err = runCmd(v, verbose, false, cmdheartbeat.Run) + err = runCmd(ctx, v, verbose, false, cmdheartbeat.Run) var errexitcode exitcode.Err require.ErrorAs(t, err, &errexitcode) - assert.Equal(t, exitcode.ErrBackoff, err.(exitcode.Err).Code) - assert.Equal(t, 0, numCalls) output, err := io.ReadAll(logFile) @@ -174,10 +174,7 @@ func TestRunCmd_BackoffLoggedWithVerbose(t *testing.T) { } func TestRunCmd_BackoffNotLogged(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is windows.") - } - + ctx := context.Background() verbose := false testServerURL, router, tearDown := setupTestServer() @@ -218,9 +215,14 @@ func TestRunCmd_BackoffNotLogged(t *testing.T) { v.Set("internal.backoff_retries", "1") v.Set("verbose", verbose) - _, _ = SetupLogging(v) + logger, err := SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) - err = runCmd(v, verbose, false, cmdheartbeat.Run) + err = runCmd(ctx, v, verbose, false, cmdheartbeat.Run) var errexitcode exitcode.Err @@ -239,7 +241,7 @@ func TestParseConfigFiles(t *testing.T) { v.Set("config", "testdata/.wakatime.cfg") v.Set("internal-config", "testdata/.wakatime-internal.cfg") - err := parseConfigFiles(v) + err := parseConfigFiles(context.Background(), v) require.NoError(t, err) assert.Equal(t, "true", v.GetString("settings.debug")) diff --git a/cmd/run_test.go b/cmd/run_test.go index 211c918a..b5104b59 100644 --- a/cmd/run_test.go +++ b/cmd/run_test.go @@ -1,6 +1,7 @@ package cmd_test import ( + "context" "encoding/json" "errors" "fmt" @@ -13,6 +14,7 @@ import ( "github.com/wakatime/wakatime-cli/cmd" "github.com/wakatime/wakatime-cli/pkg/exitcode" + "github.com/wakatime/wakatime-cli/pkg/log" "github.com/wakatime/wakatime-cli/pkg/offline" "github.com/wakatime/wakatime-cli/pkg/version" @@ -59,12 +61,12 @@ func TestRunCmd_Err(t *testing.T) { var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ return 42, errors.New("fail") } - err = cmd.RunCmd(v, false, false, cmdFn) + err = cmd.RunCmd(context.Background(), v, false, false, cmdFn) require.Error(t, err) var errexitcode exitcode.Err @@ -96,28 +98,21 @@ func TestRunCmd_Verbose_Err(t *testing.T) { defer offlineQueueFile.Close() - logFile, err := os.CreateTemp(tmpDir, "") - require.NoError(t, err) - - defer logFile.Close() - v := viper.New() v.Set("api-url", testServerURL) v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - v.Set("log-file", logFile.Name()) - v.Set("log-to-stdout", true) v.Set("offline-queue-file", offlineQueueFile.Name()) v.Set("plugin", "vim") var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ return 42, errors.New("fail") } - err = cmd.RunCmd(v, true, false, cmdFn) + err = cmd.RunCmd(context.Background(), v, true, false, cmdFn) var errexitcode exitcode.Err @@ -132,6 +127,8 @@ func TestRunCmd_SendDiagnostics_Err(t *testing.T) { testServerURL, router, tearDown := setupTestServer() defer tearDown() + ctx := context.Background() + var numCalls int router.HandleFunc("/plugins/errors", func(w http.ResponseWriter, req *http.Request) { @@ -185,28 +182,26 @@ func TestRunCmd_SendDiagnostics_Err(t *testing.T) { defer offlineQueueFile.Close() - logFile, err := os.CreateTemp(tmpDir, "") - require.NoError(t, err) - - defer logFile.Close() - v := viper.New() v.Set("api-url", testServerURL) v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - v.Set("log-file", logFile.Name()) - v.Set("log-to-stdout", true) v.Set("offline-queue-file", offlineQueueFile.Name()) v.Set("plugin", "vim") + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + ctx = log.ToContext(ctx, logger) + var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ return 42, errors.New("fail") } - err = cmd.RunCmd(v, true, true, cmdFn) + err = cmd.RunCmd(ctx, v, true, true, cmdFn) var errexitcode exitcode.Err @@ -221,6 +216,8 @@ func TestRunCmd_SendDiagnostics_Panic(t *testing.T) { testServerURL, router, tearDown := setupTestServer() defer tearDown() + ctx := context.Background() + var numCalls int router.HandleFunc("/plugins/errors", func(w http.ResponseWriter, req *http.Request) { @@ -275,29 +272,27 @@ func TestRunCmd_SendDiagnostics_Panic(t *testing.T) { defer offlineQueueFile.Close() - logFile, err := os.CreateTemp(tmpDir, "") - require.NoError(t, err) - - defer logFile.Close() - v := viper.New() v.Set("api-url", testServerURL) v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - v.Set("log-file", logFile.Name()) - v.Set("log-to-stdout", true) v.Set("offline-queue-file", offlineQueueFile.Name()) v.Set("plugin", "vim") + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + ctx = log.ToContext(ctx, logger) + var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ panic("fail") } - err = cmd.RunCmd(v, true, false, cmdFn) + err = cmd.RunCmd(ctx, v, true, false, cmdFn) var errexitcode exitcode.Err @@ -312,6 +307,8 @@ func TestRunCmd_SendDiagnostics_NoLogs_Panic(t *testing.T) { testServerURL, router, tearDown := setupTestServer() defer tearDown() + ctx := context.Background() + var numCalls int router.HandleFunc("/plugins/errors", func(w http.ResponseWriter, req *http.Request) { @@ -364,29 +361,27 @@ func TestRunCmd_SendDiagnostics_NoLogs_Panic(t *testing.T) { defer offlineQueueFile.Close() - logFile, err := os.CreateTemp(tmpDir, "") - require.NoError(t, err) - - defer logFile.Close() - v := viper.New() v.Set("api-url", testServerURL) v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - v.Set("log-file", logFile.Name()) - v.Set("log-to-stdout", true) v.Set("offline-queue-file", offlineQueueFile.Name()) v.Set("plugin", "vim") + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + ctx = log.ToContext(ctx, logger) + var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ panic("fail") } - err = cmd.RunCmd(v, false, false, cmdFn) + err = cmd.RunCmd(ctx, v, false, false, cmdFn) var errexitcode exitcode.Err @@ -401,6 +396,8 @@ func TestRunCmd_SendDiagnostics_WakaError(t *testing.T) { testServerURL, router, tearDown := setupTestServer() defer tearDown() + ctx := context.Background() + var numCalls int router.HandleFunc("/plugins/errors", func(w http.ResponseWriter, req *http.Request) { @@ -454,28 +451,26 @@ func TestRunCmd_SendDiagnostics_WakaError(t *testing.T) { defer offlineQueueFile.Close() - logFile, err := os.CreateTemp(tmpDir, "") - require.NoError(t, err) - - defer logFile.Close() - v := viper.New() v.Set("api-url", testServerURL) v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - v.Set("log-file", logFile.Name()) - v.Set("log-to-stdout", true) v.Set("offline-queue-file", offlineQueueFile.Name()) v.Set("plugin", "vim") + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + ctx = log.ToContext(ctx, logger) + var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ return 42, offline.ErrOpenDB{Err: errors.New("fail")} } - err = cmd.RunCmd(v, false, false, cmdFn) + err = cmd.RunCmd(ctx, v, false, false, cmdFn) var errexitcode exitcode.Err @@ -523,11 +518,6 @@ func TestRunCmdWithOfflineSync(t *testing.T) { version.Arch = "some architecture" version.Version = "some version" - logFile, err := os.CreateTemp(t.TempDir(), "") - require.NoError(t, err) - - defer logFile.Close() - // setup test queue offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) @@ -561,20 +551,18 @@ func TestRunCmdWithOfflineSync(t *testing.T) { v.Set("api-url", testServerURL) v.Set("entity", "/path/to/file") v.Set("key", "00000000-0000-4000-8000-000000000000") - v.Set("log-file", logFile.Name()) - v.Set("log-to-stdout", true) v.Set("offline-queue-file", offlineQueueFile.Name()) v.SetDefault("sync-offline-activity", 24) v.Set("plugin", "vim") var cmdNumCalls int - cmdFn := func(_ *viper.Viper) (int, error) { + cmdFn := func(_ context.Context, _ *viper.Viper) (int, error) { cmdNumCalls++ return exitcode.Success, nil } - err = cmd.RunCmdWithOfflineSync(v, false, false, cmdFn) + err = cmd.RunCmdWithOfflineSync(context.Background(), v, false, false, cmdFn) require.NoError(t, err) assert.Equal(t, 1, cmdNumCalls) @@ -602,7 +590,7 @@ func TestRunCmdWithOfflineSync(t *testing.T) { err = db.Close() require.NoError(t, err) - require.Len(t, stored, 0) + assert.Len(t, stored, 0) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) } diff --git a/cmd/today/today.go b/cmd/today/today.go index bae7af60..74a1a7e2 100644 --- a/cmd/today/today.go +++ b/cmd/today/today.go @@ -1,6 +1,7 @@ package today import ( + "context" "fmt" cmdapi "github.com/wakatime/wakatime-cli/cmd/api" @@ -14,8 +15,8 @@ import ( ) // Run executes the today command. -func Run(v *viper.Viper) (int, error) { - output, err := Today(v) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + output, err := Today(ctx, v) if err != nil { if errwaka, ok := err.(wakaerror.Error); ok { return errwaka.ExitCode(), fmt.Errorf("today fetch failed: %s", errwaka.Message()) @@ -27,15 +28,17 @@ func Run(v *viper.Viper) (int, error) { ) } - log.Debugln("successfully fetched today for status bar") + logger := log.Extract(ctx) + + logger.Debugln("successfully fetched today for status bar") fmt.Println(output) return exitcode.Success, nil } // Today returns a rendered summary of today's coding activity. -func Today(v *viper.Viper) (string, error) { - paramAPI, err := params.LoadAPIParams(v) +func Today(ctx context.Context, v *viper.Viper) (string, error) { + paramAPI, err := params.LoadAPIParams(ctx, v) if err != nil { return "", fmt.Errorf("failed to load API parameters: %w", err) } @@ -45,12 +48,12 @@ func Today(v *viper.Viper) (string, error) { return "", fmt.Errorf("failed to load status bar parameters: %w", err) } - apiClient, err := cmdapi.NewClient(paramAPI) + apiClient, err := cmdapi.NewClient(ctx, paramAPI) if err != nil { return "", fmt.Errorf("failed to initialize api client: %w", err) } - s, err := apiClient.Today() + s, err := apiClient.Today(ctx) if err != nil { return "", fmt.Errorf("failed fetching today from api: %w", err) } diff --git a/cmd/today/today_test.go b/cmd/today/today_test.go index c6ab76f3..e1581589 100644 --- a/cmd/today/today_test.go +++ b/cmd/today/today_test.go @@ -1,6 +1,7 @@ package today_test import ( + "context" "errors" "fmt" "io" @@ -59,7 +60,7 @@ func TestToday(t *testing.T) { v.Set("api-url", testServerURL) v.Set("plugin", plugin) - output, err := today.Today(v) + output, err := today.Today(context.Background(), v) require.NoError(t, err) assert.Equal(t, "10 secs", output) @@ -83,7 +84,7 @@ func TestToday_ErrApi(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", testServerURL) - _, err := today.Today(v) + _, err := today.Today(context.Background(), v) require.Error(t, err) var errapi api.Err @@ -116,7 +117,7 @@ func TestToday_ErrAuth(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", testServerURL) - _, err := today.Today(v) + _, err := today.Today(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth @@ -150,7 +151,7 @@ func TestToday_ErrBadRequest(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("api-url", testServerURL) - _, err := today.Today(v) + _, err := today.Today(context.Background(), v) require.Error(t, err) var errbadRequest api.ErrBadRequest @@ -168,7 +169,7 @@ func TestToday_ErrBadRequest(t *testing.T) { func TestToday_ErrAuth_UnsetAPIKey(t *testing.T) { v := viper.New() - _, err := today.Today(v) + _, err := today.Today(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth diff --git a/cmd/todaygoal/todaygoal.go b/cmd/todaygoal/todaygoal.go index 30402e9a..34c1efef 100644 --- a/cmd/todaygoal/todaygoal.go +++ b/cmd/todaygoal/todaygoal.go @@ -1,6 +1,7 @@ package todaygoal import ( + "context" "fmt" "regexp" @@ -26,8 +27,8 @@ type Params struct { } // Run executes the today-goal command. -func Run(v *viper.Viper) (int, error) { - output, err := Goal(v) +func Run(ctx context.Context, v *viper.Viper) (int, error) { + output, err := Goal(ctx, v) if err != nil { if errwaka, ok := err.(wakaerror.Error); ok { return errwaka.ExitCode(), fmt.Errorf("today goal fetch failed: %s", errwaka.Message()) @@ -39,25 +40,27 @@ func Run(v *viper.Viper) (int, error) { ) } - log.Debugln("successfully fetched today goal") + logger := log.Extract(ctx) + + logger.Debugln("successfully fetched today goal") fmt.Println(output) return exitcode.Success, nil } // Goal returns total time of given goal id for today's coding activity. -func Goal(v *viper.Viper) (string, error) { - params, err := LoadParams(v) +func Goal(ctx context.Context, v *viper.Viper) (string, error) { + params, err := LoadParams(ctx, v) if err != nil { return "", fmt.Errorf("failed to load command parameters: %w", err) } - apiClient, err := cmdapi.NewClient(params.API) + apiClient, err := cmdapi.NewClient(ctx, params.API) if err != nil { return "", fmt.Errorf("failed to initialize api client: %w", err) } - g, err := apiClient.Goal(params.GoalID) + g, err := apiClient.Goal(ctx, params.GoalID) if err != nil { return "", fmt.Errorf("failed fetching todays goal from api: %w", err) } @@ -72,8 +75,8 @@ func Goal(v *viper.Viper) (string, error) { // LoadParams loads todaygoal config params from viper.Viper instance. Returns ErrAuth // if failed to retrieve api key. -func LoadParams(v *viper.Viper) (Params, error) { - paramAPI, err := params.LoadAPIParams(v) +func LoadParams(ctx context.Context, v *viper.Viper) (Params, error) { + paramAPI, err := params.LoadAPIParams(ctx, v) if err != nil { return Params{}, fmt.Errorf("failed to load API parameters: %w", err) } diff --git a/cmd/todaygoal/todaygoal_test.go b/cmd/todaygoal/todaygoal_test.go index 98403190..d6ff6ebf 100644 --- a/cmd/todaygoal/todaygoal_test.go +++ b/cmd/todaygoal/todaygoal_test.go @@ -1,6 +1,7 @@ package todaygoal_test import ( + "context" "errors" "fmt" "io" @@ -61,7 +62,7 @@ func TestGoal(t *testing.T) { v.Set("plugin", plugin) v.Set("today-goal", "00000000-0000-4000-8000-000000000000") - output, err := todaygoal.Goal(v) + output, err := todaygoal.Goal(context.Background(), v) require.NoError(t, err) assert.Equal(t, "3 hrs 23 mins", output) @@ -88,7 +89,7 @@ func TestGoal_ErrApi(t *testing.T) { v.Set("api-url", testServerURL) v.Set("today-goal", "00000000-0000-4000-8000-000000000000") - _, err := todaygoal.Goal(v) + _, err := todaygoal.Goal(context.Background(), v) require.Error(t, err) var errapi api.Err @@ -124,7 +125,7 @@ func TestGoal_ErrAuth(t *testing.T) { v.Set("api-url", testServerURL) v.Set("today-goal", "00000000-0000-4000-8000-000000000000") - _, err := todaygoal.Goal(v) + _, err := todaygoal.Goal(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth @@ -160,7 +161,7 @@ func TestGoal_ErrBadRequest(t *testing.T) { v.Set("api-url", testServerURL) v.Set("today-goal", "00000000-0000-4000-8000-000000000000") - _, err := todaygoal.Goal(v) + _, err := todaygoal.Goal(context.Background(), v) require.Error(t, err) var errbadRequest api.ErrBadRequest @@ -178,7 +179,7 @@ func TestGoal_ErrBadRequest(t *testing.T) { func TestGoal_ErrAuth_UnsetAPIKey(t *testing.T) { v := viper.New() - _, err := todaygoal.Goal(v) + _, err := todaygoal.Goal(context.Background(), v) require.Error(t, err) var errauth api.ErrAuth @@ -198,7 +199,7 @@ func TestLoadParams_GoalID(t *testing.T) { v.Set("key", "00000000-0000-4000-8000-000000000000") v.Set("today-goal", "00000000-0000-4000-8000-000000000001") - params, err := todaygoal.LoadParams(v) + params, err := todaygoal.LoadParams(context.Background(), v) require.NoError(t, err) assert.Equal(t, "00000000-0000-4000-8000-000000000001", params.GoalID) diff --git a/cmd/version.go b/cmd/version.go index d8021a82..3d7a4b8e 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/wakatime/wakatime-cli/pkg/exitcode" @@ -9,7 +10,7 @@ import ( "github.com/spf13/viper" ) -func runVersion(v *viper.Viper) (int, error) { +func runVersion(_ context.Context, v *viper.Viper) (int, error) { if v.GetBool("verbose") { fmt.Printf( "wakatime-cli\n Version: %s\n Commit: %s\n Built: %s\n OS/Arch: %s/%s\n", diff --git a/main_test.go b/main_test.go index dbf965e5..f491d36d 100644 --- a/main_test.go +++ b/main_test.go @@ -4,6 +4,7 @@ package main_test import ( "bytes" + "context" "fmt" "io" "net/http" @@ -59,6 +60,8 @@ func testSendHeartbeats(t *testing.T, projectFolder, entity, p string) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int subfolders := project.CountSlashesInProjectFolder(projectFolder) @@ -71,7 +74,7 @@ func testSendHeartbeats(t *testing.T, projectFolder, entity, p string) { assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) // check body expectedBodyTpl, err := os.ReadFile("testdata/api_heartbeats_request_template.json") @@ -86,7 +89,7 @@ func testSendHeartbeats(t *testing.T, projectFolder, entity, p string) { entityPath, p, subfolders, - heartbeat.UserAgent(""), + heartbeat.UserAgent(ctx, ""), ) body, err := io.ReadAll(req.Body) @@ -155,6 +158,8 @@ func TestSendHeartbeats_SecondaryApiKey(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int rootPath, _ := filepath.Abs(".") @@ -168,7 +173,7 @@ func TestSendHeartbeats_SecondaryApiKey(t *testing.T) { assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAx"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) // check body expectedBodyTpl, err := os.ReadFile("testdata/api_heartbeats_request_template.json") @@ -183,7 +188,7 @@ func TestSendHeartbeats_SecondaryApiKey(t *testing.T) { entityPath, "wakatime-cli", subfolders, - heartbeat.UserAgent(""), + heartbeat.UserAgent(ctx, ""), ) body, err := io.ReadAll(req.Body) @@ -246,6 +251,8 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, req *http.Request) { @@ -256,7 +263,7 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) var filename string @@ -324,7 +331,7 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { "--verbose", ) - offlineCount, err := offline.CountHeartbeats(offlineQueueFile.Name()) + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, offlineCount) @@ -336,6 +343,8 @@ func TestSendHeartbeats_ExtraHeartbeats_SyncLegacyOfflineActivity(t *testing.T) apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, req *http.Request) { @@ -346,7 +355,7 @@ func TestSendHeartbeats_ExtraHeartbeats_SyncLegacyOfflineActivity(t *testing.T) assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) var filename string @@ -454,7 +463,7 @@ func TestSendHeartbeats_ExtraHeartbeats_SyncLegacyOfflineActivity(t *testing.T) assert.NoFileExists(t, offlineQueueFileLegacy.Name()) - offlineCount, err := offline.CountHeartbeats(offlineQueueFile.Name()) + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Zero(t, offlineCount) @@ -466,6 +475,8 @@ func TestSendHeartbeats_Err(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int projectFolder, err := filepath.Abs(".") @@ -481,7 +492,7 @@ func TestSendHeartbeats_Err(t *testing.T) { assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) // check body expectedBodyTpl, err := os.ReadFile("testdata/api_heartbeats_request_template.json") @@ -496,7 +507,7 @@ func TestSendHeartbeats_Err(t *testing.T) { entityPath, "wakatime-cli", subfolders, - heartbeat.UserAgent(""), + heartbeat.UserAgent(ctx, ""), ) body, err := io.ReadAll(req.Body) @@ -561,6 +572,8 @@ func TestSendHeartbeats_ErrAuth_InvalidAPIKEY(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int router.HandleFunc("/users/current/heartbeats.bulk", func(_ http.ResponseWriter, _ *http.Request) { @@ -611,7 +624,7 @@ func TestSendHeartbeats_ErrAuth_InvalidAPIKEY(t *testing.T) { assert.Empty(t, out) - count, err := offline.CountHeartbeats(offlineQueueFile.Name()) + count, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, count) @@ -620,6 +633,8 @@ func TestSendHeartbeats_ErrAuth_InvalidAPIKEY(t *testing.T) { } func TestSendHeartbeats_MalformedConfig(t *testing.T) { + ctx := context.Background() + tmpDir := t.TempDir() tmpInternalConfigFile, err := os.CreateTemp(tmpDir, "wakatime-internal.cfg") @@ -650,13 +665,15 @@ func TestSendHeartbeats_MalformedConfig(t *testing.T) { assert.Empty(t, out) - count, err := offline.CountHeartbeats(offlineQueueFile.Name()) + count, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, count) } func TestSendHeartbeats_MalformedInternalConfig(t *testing.T) { + ctx := context.Background() + tmpDir := t.TempDir() offlineQueueFile, err := os.CreateTemp(tmpDir, "") @@ -687,7 +704,7 @@ func TestSendHeartbeats_MalformedInternalConfig(t *testing.T) { assert.Empty(t, out) - count, err := offline.CountHeartbeats(offlineQueueFile.Name()) + count, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) require.NoError(t, err) assert.Equal(t, 1, count) @@ -702,6 +719,8 @@ func TestFileExperts(t *testing.T) { subfolders := project.CountSlashesInProjectFolder(projectFolder) + ctx := context.Background() + var numCalls int router.HandleFunc("/users/current/file_experts", @@ -713,7 +732,7 @@ func TestFileExperts(t *testing.T) { assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) // check body expectedBodyTpl, err := os.ReadFile("testdata/api_file_experts_request_template.json") @@ -777,6 +796,8 @@ func TestTodayGoal(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int tmpDir := t.TempDir() @@ -799,7 +820,7 @@ func TestTodayGoal(t *testing.T) { assert.Equal(t, http.MethodGet, req.Method) assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) // write response f, err := os.Open("testdata/api_goals_id_response.json") @@ -830,6 +851,8 @@ func TestTodaySummary(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int tmpDir := t.TempDir() @@ -851,7 +874,7 @@ func TestTodaySummary(t *testing.T) { assert.Equal(t, http.MethodGet, req.Method) assert.Equal(t, []string{"application/json"}, req.Header["Accept"]) assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) - assert.Equal(t, []string{heartbeat.UserAgent("")}, req.Header["User-Agent"]) + assert.Equal(t, []string{heartbeat.UserAgent(ctx, "")}, req.Header["User-Agent"]) // write response f, err := os.Open("testdata/api_statusbar_today_response.json") @@ -976,6 +999,8 @@ func TestPrintOfflineHeartbeats(t *testing.T) { apiURL, router, close := setupTestServer() defer close() + ctx := context.Background() + router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) _, err := io.Copy(w, strings.NewReader("500 error test")) @@ -1054,7 +1079,7 @@ func TestPrintOfflineHeartbeats(t *testing.T) { offlineHeartbeatStr := fmt.Sprintf( string(offlineHeartbeat), entity, subfolders, - heartbeat.UserAgent(""), + heartbeat.UserAgent(ctx, ""), ) assert.Equal(t, offlineHeartbeatStr+"\n", out) @@ -1062,13 +1087,13 @@ func TestPrintOfflineHeartbeats(t *testing.T) { func TestUserAgent(t *testing.T) { out := runWakatimeCli(t, &bytes.Buffer{}, "--user-agent") - assert.Equal(t, fmt.Sprintf("%s\n", heartbeat.UserAgent("")), out) + assert.Equal(t, fmt.Sprintf("%s\n", heartbeat.UserAgent(context.Background(), "")), out) } func TestUserAgentWithPlugin(t *testing.T) { out := runWakatimeCli(t, &bytes.Buffer{}, "--user-agent", "--plugin", "Wakatime/1.0.4") - assert.Equal(t, fmt.Sprintf("%s\n", heartbeat.UserAgent("Wakatime/1.0.4")), out) + assert.Equal(t, fmt.Sprintf("%s\n", heartbeat.UserAgent(context.Background(), "Wakatime/1.0.4")), out) } func TestVersion(t *testing.T) { diff --git a/pkg/api/api.go b/pkg/api/api.go index 677a24df..4e0a6fa9 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "fmt" "net" @@ -62,7 +63,7 @@ func NewClient(baseURL string, opts ...Option) *Client { // Do executes c.doFunc(), which in turn allows wrapping c.client.Do() and manipulating // the request behavior of the api client. -func (c *Client) Do(req *http.Request) (*http.Response, error) { +func (c *Client) Do(ctx context.Context, req *http.Request) (*http.Response, error) { resp, err := c.doFunc(c, req) if err != nil { // don't set alternate host if there's a custom api url @@ -76,15 +77,16 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { } c.client = &http.Client{ - Transport: NewTransportWithHostVerificationDisabled(), + Transport: NewTransportWithHostVerificationDisabled(ctx), } req.URL.Host = BaseIPAddrv4 - if isLocalIPv6() { + if isLocalIPv6(ctx) { req.URL.Host = BaseIPAddrv6 } - log.Debugf("dns error, will retry with host ip '%s': %s", req.URL.Host, err) + logger := log.Extract(ctx) + logger.Debugf("dns error, will retry with host ip '%s': %s", req.URL.Host, err) resp, errRetry := c.doFunc(c, req) if errRetry != nil { @@ -97,16 +99,18 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return resp, nil } -func isLocalIPv6() bool { +func isLocalIPv6(ctx context.Context) bool { + logger := log.Extract(ctx) + conn, err := net.Dial("udp", fmt.Sprintf("%s:80", BaseIPAddrv4)) if err != nil { - log.Warnf("failed dialing to detect default local ip address: %s", err) + logger.Warnf("failed dialing to detect default local ip address: %s", err) return true } defer func() { if err := conn.Close(); err != nil { - log.Debugf("failed to close connection to api wakatime: %s", err) + logger.Debugf("failed to close connection to api wakatime: %s", err) } }() diff --git a/pkg/api/diagnostic.go b/pkg/api/diagnostic.go index e8ce6fa1..f79faa03 100644 --- a/pkg/api/diagnostic.go +++ b/pkg/api/diagnostic.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -24,10 +25,17 @@ type diagnosticsBody struct { } // SendDiagnostics sends diagnostics to the WakaTime api. -func (c *Client) SendDiagnostics(plugin string, panicked bool, diagnostics ...diagnostic.Diagnostic) error { +func (c *Client) SendDiagnostics( + ctx context.Context, + plugin string, + panicked bool, + diagnostics ...diagnostic.Diagnostic, +) error { + logger := log.Extract(ctx) + url := c.baseURL + "/plugins/errors" - log.Debugf("sending diagnostic data to api at %s", url) + logger.Debugf("sending diagnostic data to api at %s", url) body := diagnosticsBody{ Architecture: version.Arch, @@ -62,7 +70,7 @@ func (c *Client) SendDiagnostics(plugin string, panicked bool, diagnostics ...di req.Header.Set("Content-Type", "application/json") - resp, err := c.Do(req) + resp, err := c.Do(ctx, req) if err != nil { return Err{Err: fmt.Errorf("failed making request to %q: %s", url, err)} } diff --git a/pkg/api/diagnostic_test.go b/pkg/api/diagnostic_test.go index 9fb37aeb..e654941c 100644 --- a/pkg/api/diagnostic_test.go +++ b/pkg/api/diagnostic_test.go @@ -1,6 +1,7 @@ package api_test import ( + "context" "encoding/json" "fmt" "io" @@ -75,7 +76,7 @@ func TestClient_SendDiagnostics(t *testing.T) { } c := api.NewClient(url) - err := c.SendDiagnostics("vim", false, diagnostics...) + err := c.SendDiagnostics(context.Background(), "vim", false, diagnostics...) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 1 }, time.Second, 50*time.Millisecond) diff --git a/pkg/api/fileexperts.go b/pkg/api/fileexperts.go index 46a779b4..3e8bd10e 100644 --- a/pkg/api/fileexperts.go +++ b/pkg/api/fileexperts.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -17,7 +18,9 @@ import ( // ErrRequest is returned upon request failure with no received response from api. // ErrAuth is returned upon receiving a 401 Unauthorized api response. // Err is returned on any other api response related error. -func (c *Client) FileExperts(heartbeats []heartbeat.Heartbeat) ([]heartbeat.Result, error) { +func (c *Client) FileExperts(ctx context.Context, heartbeats []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + url := c.baseURL + "/users/current/file_experts" // change from heartbeat.Heartbeat to fileexpert.Entity @@ -33,7 +36,7 @@ func (c *Client) FileExperts(heartbeats []heartbeat.Heartbeat) ([]heartbeat.Resu return nil, fmt.Errorf("failed to json encode body: %s", err) } - log.Debugf("file-experts: %s", string(data)) + logger.Debugf("file-experts: %s", string(data)) req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(data)) if err != nil { @@ -45,7 +48,7 @@ func (c *Client) FileExperts(heartbeats []heartbeat.Heartbeat) ([]heartbeat.Resu // set auth header here for every request due to multiple api key support setAuthHeader(req, heartbeats[0].APIKey) - resp, err := c.Do(req) + resp, err := c.Do(ctx, req) if err != nil { return nil, Err{Err: fmt.Errorf("failed making request to %q: %s", url, err)} } diff --git a/pkg/api/fileexperts_test.go b/pkg/api/fileexperts_test.go index 60acb785..d3252649 100644 --- a/pkg/api/fileexperts_test.go +++ b/pkg/api/fileexperts_test.go @@ -1,6 +1,7 @@ package api_test import ( + "context" "errors" "io" "net/http" @@ -17,6 +18,8 @@ import ( ) func TestClient_FileExperts(t *testing.T) { + ctx := context.Background() + tests := []int{ http.StatusOK, http.StatusAccepted, @@ -56,7 +59,7 @@ func TestClient_FileExperts(t *testing.T) { }) c := api.NewClient(url) - results, err := c.FileExperts([]heartbeat.Heartbeat{ + results, err := c.FileExperts(ctx, []heartbeat.Heartbeat{ { APIKey: "00000000-0000-4000-8000-000000000000", Entity: "/tmp/main.go", @@ -131,7 +134,7 @@ func TestClient_FileExperts_Err(t *testing.T) { }) c := api.NewClient(url) - _, err := c.FileExperts([]heartbeat.Heartbeat{ + _, err := c.FileExperts(context.Background(), []heartbeat.Heartbeat{ { APIKey: "00000000-0000-4000-8000-000000000000", Entity: "/tmp/main.go", @@ -160,7 +163,7 @@ func TestClient_FileExperts_ErrAuth(t *testing.T) { }) c := api.NewClient(url) - _, err := c.FileExperts([]heartbeat.Heartbeat{ + _, err := c.FileExperts(context.Background(), []heartbeat.Heartbeat{ { APIKey: "00000000-0000-4000-8000-000000000000", Entity: "/tmp/main.go", @@ -189,7 +192,7 @@ func TestClient_FileExperts_ErrBadRequest(t *testing.T) { }) c := api.NewClient(url) - _, err := c.FileExperts([]heartbeat.Heartbeat{ + _, err := c.FileExperts(context.Background(), []heartbeat.Heartbeat{ { APIKey: "00000000-0000-4000-8000-000000000000", Entity: "/tmp/main.go", @@ -207,7 +210,7 @@ func TestClient_FileExperts_ErrBadRequest(t *testing.T) { func TestClient_FileExperts_InvalidUrl(t *testing.T) { c := api.NewClient("invalid-url") - _, err := c.FileExperts([]heartbeat.Heartbeat{ + _, err := c.FileExperts(context.Background(), []heartbeat.Heartbeat{ { APIKey: "00000000-0000-4000-8000-000000000000", Entity: "/tmp/main.go", diff --git a/pkg/api/goal.go b/pkg/api/goal.go index 25c6520d..948683c3 100644 --- a/pkg/api/goal.go +++ b/pkg/api/goal.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "fmt" "io" @@ -14,7 +15,7 @@ import ( // ErrRequest is returned upon request failure with no received response from api. // ErrAuth is returned upon receiving a 401 Unauthorized api response. // Err is returned on any other api response related error. -func (c *Client) Goal(id string) (*goal.Goal, error) { +func (c *Client) Goal(ctx context.Context, id string) (*goal.Goal, error) { url := c.baseURL + "/users/current/goals/" + id req, err := http.NewRequest(http.MethodGet, url, nil) @@ -24,7 +25,7 @@ func (c *Client) Goal(id string) (*goal.Goal, error) { req.Header.Set("Content-Type", "application/json") - resp, err := c.Do(req) + resp, err := c.Do(ctx, req) if err != nil { return nil, Err{Err: fmt.Errorf("failed to make request to %q: %s", url, err)} } diff --git a/pkg/api/goal_test.go b/pkg/api/goal_test.go index 01359681..3daa0c75 100644 --- a/pkg/api/goal_test.go +++ b/pkg/api/goal_test.go @@ -1,6 +1,7 @@ package api_test import ( + "context" "errors" "fmt" "io" @@ -40,7 +41,7 @@ func TestClient_Goal(t *testing.T) { }) c := api.NewClient(u) - goal, err := c.Goal("00000000-0000-4000-8000-000000000000") + goal, err := c.Goal(context.Background(), "00000000-0000-4000-8000-000000000000") require.NoError(t, err) @@ -65,8 +66,10 @@ func TestClient_GoalWithTimeout(t *testing.T) { }) opts := []api.Option{api.WithTimeout(20 * time.Millisecond)} + c := api.NewClient(u, opts...) - _, err := c.Goal("00000000-0000-4000-8000-000000000000") + + _, err := c.Goal(context.Background(), "00000000-0000-4000-8000-000000000000") require.Error(t, err) errMsg := fmt.Sprintf("error %q does not contain string 'Timeout'", err) @@ -95,7 +98,8 @@ func TestClient_Goal_Err(t *testing.T) { }) c := api.NewClient(u) - _, err := c.Goal("00000000-0000-4000-8000-000000000000") + + _, err := c.Goal(context.Background(), "00000000-0000-4000-8000-000000000000") var apierr api.Err @@ -117,7 +121,8 @@ func TestClient_Goal_ErrAuth(t *testing.T) { }) c := api.NewClient(u) - _, err := c.Goal("00000000-0000-4000-8000-000000000000") + + _, err := c.Goal(context.Background(), "00000000-0000-4000-8000-000000000000") var errauth api.ErrAuth @@ -140,7 +145,8 @@ func TestClient_Goal_ErrBadRequest(t *testing.T) { }) c := api.NewClient(u) - _, err := c.Goal("00000000-0000-4000-8000-000000000000") + + _, err := c.Goal(context.Background(), "00000000-0000-4000-8000-000000000000") var errbadRequest api.ErrBadRequest @@ -150,7 +156,8 @@ func TestClient_Goal_ErrBadRequest(t *testing.T) { func TestClient_Goal_ErrInvalidUrl(t *testing.T) { c := api.NewClient("invalid-url") - _, err := c.Goal("00000000-0000-4000-8000-000000000000") + + _, err := c.Goal(context.Background(), "00000000-0000-4000-8000-000000000000") var apierr api.Err diff --git a/pkg/api/heartbeat.go b/pkg/api/heartbeat.go index 0f8b4c8b..76f594c2 100644 --- a/pkg/api/heartbeat.go +++ b/pkg/api/heartbeat.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -21,10 +22,12 @@ import ( // ErrRequest is returned upon request failure with no received response from api. // ErrAuth is returned upon receiving a 401 Unauthorized api response. // Err is returned on any other api response related error. -func (c *Client) SendHeartbeats(heartbeats []heartbeat.Heartbeat) ([]heartbeat.Result, error) { +func (c *Client) SendHeartbeats(ctx context.Context, heartbeats []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + url := c.baseURL + "/users/current/heartbeats.bulk" - log.Debugf("sending %d heartbeat(s) to api at %s", len(heartbeats), url) + logger.Debugf("sending %d heartbeat(s) to api at %s", len(heartbeats), url) var results []heartbeat.Result @@ -32,7 +35,7 @@ func (c *Client) SendHeartbeats(heartbeats []heartbeat.Heartbeat) ([]heartbeat.R keys := sortKeys(grouped) for _, k := range keys { - res, err := c.sendHeartbeats(url, grouped[k]) + res, err := c.sendHeartbeats(ctx, url, grouped[k]) if err != nil { return nil, err } @@ -43,13 +46,15 @@ func (c *Client) SendHeartbeats(heartbeats []heartbeat.Heartbeat) ([]heartbeat.R return results, nil } -func (c *Client) sendHeartbeats(url string, heartbeats []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - data, err := json.Marshal(heartbeats) +func (c *Client) sendHeartbeats(ctx context.Context, url string, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + + data, err := json.Marshal(hh) if err != nil { return nil, fmt.Errorf("failed to json encode body: %s", err) } - log.Debugf("heartbeats: %s", string(data)) + logger.Debugf("heartbeats: %s", string(data)) req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(data)) if err != nil { @@ -59,9 +64,9 @@ func (c *Client) sendHeartbeats(url string, heartbeats []heartbeat.Heartbeat) ([ req.Header.Set("Content-Type", "application/json") // set auth header here for every request due to multiple api key support - setAuthHeader(req, heartbeats[0].APIKey) + setAuthHeader(req, hh[0].APIKey) - resp, err := c.Do(req) + resp, err := c.Do(ctx, req) if err != nil { return nil, Err{Err: fmt.Errorf("failed making request to %q: %s", url, err)} } @@ -89,7 +94,7 @@ func (c *Client) sendHeartbeats(url string, heartbeats []heartbeat.Heartbeat) ([ )} } - results, err := ParseHeartbeatResponses(body) + results, err := ParseHeartbeatResponses(ctx, body) if err != nil { return nil, Err{Err: fmt.Errorf("failed parsing results from %q: %s", url, err)} } @@ -98,7 +103,7 @@ func (c *Client) sendHeartbeats(url string, heartbeats []heartbeat.Heartbeat) ([ } // ParseHeartbeatResponses parses the aggregated responses returned by the heartbeat bulk endpoint. -func ParseHeartbeatResponses(data []byte) ([]heartbeat.Result, error) { +func ParseHeartbeatResponses(ctx context.Context, data []byte) ([]heartbeat.Result, error) { var responsesBody struct { Responses [][]json.RawMessage `json:"responses"` } @@ -111,7 +116,7 @@ func ParseHeartbeatResponses(data []byte) ([]heartbeat.Result, error) { var results []heartbeat.Result for n, r := range responsesBody.Responses { - result, err := parseHeartbeatResponse(r) + result, err := parseHeartbeatResponse(ctx, r) if err != nil { return nil, fmt.Errorf("failed parsing result #%d: %s. body: %q", n, err, string(data)) } @@ -123,7 +128,7 @@ func ParseHeartbeatResponses(data []byte) ([]heartbeat.Result, error) { } // parseHeartbeatResponse parses one response of the aggregated responses returned by the heartbeat bulk endpoint. -func parseHeartbeatResponse(data []json.RawMessage) (heartbeat.Result, error) { +func parseHeartbeatResponse(ctx context.Context, data []json.RawMessage) (heartbeat.Result, error) { var result heartbeat.Result type responseBody struct { @@ -136,7 +141,7 @@ func parseHeartbeatResponse(data []json.RawMessage) (heartbeat.Result, error) { } if result.Status < http.StatusOK || result.Status > 299 { - resultErrors, err := parseHeartbeatResponseError(data[0]) + resultErrors, err := parseHeartbeatResponseError(ctx, data[0]) if err != nil { return heartbeat.Result{}, fmt.Errorf("failed to parse result errors: %s", err) } @@ -158,7 +163,9 @@ func parseHeartbeatResponse(data []json.RawMessage) (heartbeat.Result, error) { } // parseHeartbeatResponseError parses one error of the aggregated responses returned by the heartbeat bulk endpoint. -func parseHeartbeatResponseError(data json.RawMessage) ([]string, error) { +func parseHeartbeatResponseError(ctx context.Context, data json.RawMessage) ([]string, error) { + logger := log.Extract(ctx) + var errs []string type responseBodyErr struct { @@ -171,7 +178,7 @@ func parseHeartbeatResponseError(data json.RawMessage) ([]string, error) { err := json.Unmarshal(data, &responseBodyErr{Error: &resultError}) if err != nil { - log.Debugf("failed to parse json heartbeat error or 'error' key not found: %s", err) + logger.Debugf("failed to parse json heartbeat error or 'error' key not found: %s", err) } if resultError != "" { @@ -184,7 +191,7 @@ func parseHeartbeatResponseError(data json.RawMessage) ([]string, error) { err = json.Unmarshal(data, &responseBodyErr{Errors: &resultErrors}) if err != nil { - log.Debugf("failed to parse json heartbeat errors or 'errors' key not found: %s", err) + logger.Debugf("failed to parse json heartbeat errors or 'errors' key not found: %s", err) } if resultErrors == nil { diff --git a/pkg/api/heartbeat_test.go b/pkg/api/heartbeat_test.go index 5bdfcc15..02c2c2aa 100644 --- a/pkg/api/heartbeat_test.go +++ b/pkg/api/heartbeat_test.go @@ -1,6 +1,7 @@ package api_test import ( + "context" "errors" "io" "net/http" @@ -55,7 +56,7 @@ func TestClient_SendHeartbeats(t *testing.T) { }) c := api.NewClient(url) - results, err := c.SendHeartbeats(testHeartbeats()) + results, err := c.SendHeartbeats(context.Background(), testHeartbeats()) require.NoError(t, err) // check via assert.Equal on complete slice here, to assert exact order of results, @@ -127,7 +128,7 @@ func TestClient_SendHeartbeats_MultipleApiKey(t *testing.T) { hh := testHeartbeats() hh[1].APIKey = "00000000-0000-4000-8000-000000000001" - _, err := c.SendHeartbeats(hh) + _, err := c.SendHeartbeats(context.Background(), hh) require.NoError(t, err) assert.Eventually(t, func() bool { return numCalls == 2 }, time.Second, 50*time.Millisecond) @@ -146,7 +147,8 @@ func TestClient_SendHeartbeats_Err(t *testing.T) { }) c := api.NewClient(url) - _, err := c.SendHeartbeats(testHeartbeats()) + + _, err := c.SendHeartbeats(context.Background(), testHeartbeats()) var errapi api.Err @@ -168,7 +170,8 @@ func TestClient_SendHeartbeats_ErrAuth(t *testing.T) { }) c := api.NewClient(url) - _, err := c.SendHeartbeats(testHeartbeats()) + + _, err := c.SendHeartbeats(context.Background(), testHeartbeats()) var errauth api.ErrAuth @@ -190,7 +193,8 @@ func TestClient_SendHeartbeats_ErrBadRequest(t *testing.T) { }) c := api.NewClient(url) - _, err := c.SendHeartbeats(testHeartbeats()) + + _, err := c.SendHeartbeats(context.Background(), testHeartbeats()) var errbadRequest api.ErrBadRequest @@ -201,7 +205,8 @@ func TestClient_SendHeartbeats_ErrBadRequest(t *testing.T) { func TestClient_SendHeartbeats_InvalidUrl(t *testing.T) { c := api.NewClient("invalid-url") - _, err := c.SendHeartbeats(testHeartbeats()) + + _, err := c.SendHeartbeats(context.Background(), testHeartbeats()) var apierr api.Err @@ -212,7 +217,7 @@ func TestParseHeartbeatResponses(t *testing.T) { data, err := os.ReadFile("testdata/api_heartbeats_response.json") require.NoError(t, err) - results, err := api.ParseHeartbeatResponses(data) + results, err := api.ParseHeartbeatResponses(context.Background(), data) require.NoError(t, err) // check via assert.Equal on complete slice here, to assert exact order of results, @@ -260,7 +265,7 @@ func TestParseHeartbeatResponses_Error(t *testing.T) { data, err := os.ReadFile("testdata/api_heartbeats_response_error.json") require.NoError(t, err) - results, err := api.ParseHeartbeatResponses(data) + results, err := api.ParseHeartbeatResponses(context.Background(), data) require.NoError(t, err) // asserting here the exact order of results, which is assumed to exactly match the request order @@ -281,7 +286,7 @@ func TestParseHeartbeatResponses_Errors(t *testing.T) { data, err := os.ReadFile("testdata/api_heartbeats_response_errors.json") require.NoError(t, err) - results, err := api.ParseHeartbeatResponses(data) + results, err := api.ParseHeartbeatResponses(context.Background(), data) require.NoError(t, err) // asserting here the exact order of results, which is assumed to exactly match the request order diff --git a/pkg/api/option.go b/pkg/api/option.go index 49e455a1..6f554b45 100644 --- a/pkg/api/option.go +++ b/pkg/api/option.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/x509" "fmt" "net/http" @@ -78,23 +79,25 @@ func WithNTLM(creds string) (Option, error) { } // WithNTLMRequestRetry will, upon request failure, retry with ntlm authentication. -func WithNTLMRequestRetry(creds string) (Option, error) { +func WithNTLMRequestRetry(ctx context.Context, creds string) (Option, error) { withNTLM, err := WithNTLM(creds) if err != nil { return Option(func(*Client) {}), err } return func(c *Client) { + logger := log.Extract(ctx) + next := c.doFunc c.doFunc = func(cl *Client, req *http.Request) (*http.Response, error) { resp, err := next(c, req) if err != nil { - log.Errorf("request to api failed with error %q. Will retry with ntlm auth", err) + logger.Errorf("request to api failed with error %q. Will retry with ntlm auth", err) clCopy := cl withNTLM(clCopy) - return clCopy.Do(req) + return clCopy.Do(ctx, req) } return resp, nil @@ -177,8 +180,8 @@ func WithTimezone(timezone string) Option { // WithUserAgent sets the User-Agent header on all requests, including the passed // in value for plugin. -func WithUserAgent(plugin string) Option { - userAgent := heartbeat.UserAgent(plugin) +func WithUserAgent(ctx context.Context, plugin string) Option { + userAgent := heartbeat.UserAgent(ctx, plugin) return func(c *Client) { next := c.doFunc diff --git a/pkg/api/option_test.go b/pkg/api/option_test.go index e33ea658..ed2285cb 100644 --- a/pkg/api/option_test.go +++ b/pkg/api/option_test.go @@ -1,6 +1,7 @@ package api_test import ( + "context" "encoding/base64" "fmt" "net/http" @@ -19,6 +20,8 @@ import ( ) func TestOption_WithAuth(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { User string AuthHeaderValue string @@ -55,7 +58,8 @@ func TestOption_WithAuth(t *testing.T) { require.NoError(t, err) c := api.NewClient("", []api.Option{withAuth}...) - resp, err := c.Do(req) + + resp, err := c.Do(ctx, req) require.NoError(t, err) defer resp.Body.Close() @@ -83,7 +87,8 @@ func TestOption_WithHostname(t *testing.T) { require.NoError(t, err) c := api.NewClient("", opts...) - resp, err := c.Do(req) + + resp, err := c.Do(context.Background(), req) require.NoError(t, err) defer resp.Body.Close() @@ -109,7 +114,8 @@ func TestOption_WithInvalidHostname(t *testing.T) { require.NoError(t, err) c := api.NewClient("", opts...) - resp, err := c.Do(req) + + resp, err := c.Do(context.Background(), req) require.NoError(t, err) defer resp.Body.Close() @@ -118,6 +124,8 @@ func TestOption_WithInvalidHostname(t *testing.T) { } func TestOption_WithNTLM(t *testing.T) { + ctx := context.Background() + tests := map[string]string{ "default": `domain\\john:123456`, "useronly": `domain\\john`, @@ -159,7 +167,8 @@ func TestOption_WithNTLM(t *testing.T) { require.NoError(t, err) c := api.NewClient("", []api.Option{withNTLM}...) - resp, err := c.Do(req) + + resp, err := c.Do(ctx, req) require.NoError(t, err) defer resp.Body.Close() @@ -173,6 +182,8 @@ func TestOption_WithNTLMRequestRetry(t *testing.T) { url, router, close := setupTestServer() defer close() + ctx := context.Background() + var numCalls int router.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { @@ -211,14 +222,15 @@ func TestOption_WithNTLMRequestRetry(t *testing.T) { assert.Equal(t, []string{"NTLM " + base64.StdEncoding.EncodeToString(msg)}, authHeader) }) - withNTLMRetry, err := api.WithNTLMRequestRetry(`domain\\john:secret`) + withNTLMRetry, err := api.WithNTLMRequestRetry(ctx, `domain\\john:secret`) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) c := api.NewClient("", []api.Option{withNTLMRetry}...) - resp, err := c.Do(req) + + resp, err := c.Do(ctx, req) require.NoError(t, err) defer resp.Body.Close() @@ -245,7 +257,8 @@ func TestOption_WithProxy(t *testing.T) { require.NoError(t, err) c := api.NewClient("", opts...) - resp, err := c.Do(req) + + resp, err := c.Do(context.Background(), req) require.NoError(t, err) defer resp.Body.Close() @@ -257,6 +270,8 @@ func TestOption_WithUserAgent(t *testing.T) { url, router, tearDown := setupTestServer() defer tearDown() + ctx := context.Background() + var numCalls int router.HandleFunc("/", func(_ http.ResponseWriter, req *http.Request) { @@ -276,13 +291,14 @@ func TestOption_WithUserAgent(t *testing.T) { numCalls++ }) - opts := []api.Option{api.WithUserAgent("testplugin")} + opts := []api.Option{api.WithUserAgent(ctx, "testplugin")} req, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) c := api.NewClient("", opts...) - resp, err := c.Do(req) + + resp, err := c.Do(ctx, req) require.NoError(t, err) defer resp.Body.Close() @@ -294,6 +310,8 @@ func TestOption_WithUserAgentUnknownPlugin(t *testing.T) { url, router, tearDown := setupTestServer() defer tearDown() + ctx := context.Background() + var numCalls int router.HandleFunc("/", func(_ http.ResponseWriter, req *http.Request) { @@ -313,13 +331,14 @@ func TestOption_WithUserAgentUnknownPlugin(t *testing.T) { numCalls++ }) - opts := []api.Option{api.WithUserAgent("")} + opts := []api.Option{api.WithUserAgent(ctx, "")} req, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) c := api.NewClient("", opts...) - resp, err := c.Do(req) + + resp, err := c.Do(ctx, req) require.NoError(t, err) defer resp.Body.Close() @@ -345,7 +364,8 @@ func TestOption_WithTimezone(t *testing.T) { require.NoError(t, err) c := api.NewClient("", opts...) - resp, err := c.Do(req) + + resp, err := c.Do(context.Background(), req) require.NoError(t, err) defer resp.Body.Close() diff --git a/pkg/api/statusbar.go b/pkg/api/statusbar.go index e17d8f48..b543c603 100644 --- a/pkg/api/statusbar.go +++ b/pkg/api/statusbar.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "fmt" "io" @@ -14,7 +15,7 @@ import ( // ErrRequest is returned upon request failure with no received response from api. // ErrAuth is returned upon receiving a 401 Unauthorized api response. // Err is returned on any other api response related error. -func (c *Client) Today() (*summary.Summary, error) { +func (c *Client) Today(ctx context.Context) (*summary.Summary, error) { url := c.baseURL + "/users/current/statusbar/today" req, err := http.NewRequest(http.MethodGet, url, nil) @@ -27,7 +28,7 @@ func (c *Client) Today() (*summary.Summary, error) { q := req.URL.Query() req.URL.RawQuery = q.Encode() - resp, err := c.Do(req) + resp, err := c.Do(ctx, req) if err != nil { return nil, Err{fmt.Errorf("failed to make request to %q: %s", url, err)} } diff --git a/pkg/api/statusbar_test.go b/pkg/api/statusbar_test.go index 5f497999..4f0b24d1 100644 --- a/pkg/api/statusbar_test.go +++ b/pkg/api/statusbar_test.go @@ -1,6 +1,7 @@ package api_test import ( + "context" "errors" "fmt" "io" @@ -40,8 +41,8 @@ func TestClient_StatusBar(t *testing.T) { }) c := api.NewClient(u) - s, err := c.Today() + s, err := c.Today(context.Background()) require.NoError(t, err) assert.Equal(t, s, testSummary()) @@ -65,7 +66,8 @@ func TestClient_StatusBarWithTimeout(t *testing.T) { opts := []api.Option{api.WithTimeout(20 * time.Millisecond)} c := api.NewClient(u, opts...) - _, err := c.Today() + + _, err := c.Today(context.Background()) require.Error(t, err) errMsg := fmt.Sprintf("error %q does not contain string 'Timeout'", err) @@ -93,7 +95,8 @@ func TestClient_StatusBar_Err(t *testing.T) { }) c := api.NewClient(u) - _, err := c.Today() + + _, err := c.Today(context.Background()) var apierr api.Err @@ -114,7 +117,8 @@ func TestClient_StatusBar_ErrAuth(t *testing.T) { }) c := api.NewClient(u) - _, err := c.Today() + + _, err := c.Today(context.Background()) var errauth api.ErrAuth @@ -136,7 +140,8 @@ func TestClient_StatusBar_ErrBadRequest(t *testing.T) { }) c := api.NewClient(u) - _, err := c.Today() + + _, err := c.Today(context.Background()) var errbadRequest api.ErrBadRequest @@ -146,7 +151,8 @@ func TestClient_StatusBar_ErrBadRequest(t *testing.T) { func TestClient_StatusBar_InvalidUrl(t *testing.T) { c := api.NewClient("invalid-url") - _, err := c.Today() + + _, err := c.Today(context.Background()) var apierr api.Err diff --git a/pkg/api/transport.go b/pkg/api/transport.go index 1ec3f7c7..d7b75c24 100644 --- a/pkg/api/transport.go +++ b/pkg/api/transport.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/tls" "crypto/x509" "net/http" @@ -99,12 +100,12 @@ func NewTransport() *http.Transport { } // NewTransportWithHostVerificationDisabled initializes a new http.Transport with disabled host verification. -func NewTransportWithHostVerificationDisabled() *http.Transport { +func NewTransportWithHostVerificationDisabled(ctx context.Context) *http.Transport { t := NewTransport() t.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, - RootCAs: CACerts(), + RootCAs: CACerts(ctx), ServerName: serverName, } @@ -121,14 +122,16 @@ func LazyCreateNewTransport(c *Client) *http.Transport { } // CACerts returns a root cert pool with the system's cacerts and LetsEncrypt's root certs. -func CACerts() *x509.CertPool { - certs, err := loadSystemRoots() +func CACerts(ctx context.Context) *x509.CertPool { + logger := log.Extract(ctx) + + certs, err := loadSystemRoots(ctx) if err != nil { - log.Warnf("unable to use system cert pool: %s", err) + logger.Warnf("unable to use system cert pool: %s", err) } if certs == nil { - log.Warnf("system cert pool empty") + logger.Warnf("system cert pool empty") certs = x509.NewCertPool() } diff --git a/pkg/api/transport_other.go b/pkg/api/transport_other.go index 2997b4e4..887f0d4f 100644 --- a/pkg/api/transport_other.go +++ b/pkg/api/transport_other.go @@ -3,9 +3,10 @@ package api import ( + "context" "crypto/x509" ) -func loadSystemRoots() (*x509.CertPool, error) { +func loadSystemRoots(_ context.Context) (*x509.CertPool, error) { return x509.SystemCertPool() } diff --git a/pkg/api/transport_windows.go b/pkg/api/transport_windows.go index 01ef90a4..1b89bd50 100644 --- a/pkg/api/transport_windows.go +++ b/pkg/api/transport_windows.go @@ -3,6 +3,7 @@ package api import ( + "context" "crypto/x509" "runtime/debug" "syscall" @@ -11,10 +12,12 @@ import ( "github.com/wakatime/wakatime-cli/pkg/log" ) -func loadSystemRoots() (*x509.CertPool, error) { +func loadSystemRoots(ctx context.Context) (*x509.CertPool, error) { defer func() { + logger := log.Extract(ctx) + if err := recover(); err != nil { - log.Errorf("failed to load system roots on Windows. panicked: %v. Stack: %s", err, string(debug.Stack())) + logger.Errorf("failed to load system roots on Windows. panicked: %v. Stack: %s", err, string(debug.Stack())) } }() diff --git a/pkg/apikey/apikey.go b/pkg/apikey/apikey.go index f6675483..8f23eb71 100644 --- a/pkg/apikey/apikey.go +++ b/pkg/apikey/apikey.go @@ -1,6 +1,8 @@ package apikey import ( + "context" + "github.com/wakatime/wakatime-cli/pkg/heartbeat" "github.com/wakatime/wakatime-cli/pkg/log" "github.com/wakatime/wakatime-cli/pkg/regex" @@ -27,11 +29,12 @@ type MapPattern struct { // for a heartbeat following the provided configurations. func WithReplacing(config Config) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute api key replacing") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute api key replacing") for n, h := range hh { - result, ok := MatchPattern(h.Entity, config.MapPatterns) + result, ok := MatchPattern(ctx, h.Entity, config.MapPatterns) if ok { hh[n].APIKey = result @@ -41,20 +44,22 @@ func WithReplacing(config Config) heartbeat.HandleOption { hh[n].APIKey = config.DefaultAPIKey } - return next(hh) + return next(ctx, hh) } } } // MatchPattern matches regex against entity's path to find alternate api key. -func MatchPattern(fp string, patterns []MapPattern) (string, bool) { +func MatchPattern(ctx context.Context, fp string, patterns []MapPattern) (string, bool) { + logger := log.Extract(ctx) + for _, pattern := range patterns { - if pattern.Regex.MatchString(fp) { - log.Debugf("api key pattern %q matched path %q", pattern.Regex.String(), fp) + if pattern.Regex.MatchString(ctx, fp) { + logger.Debugf("api key pattern %q matched path %q", pattern.Regex.String(), fp) return pattern.APIKey, true } - log.Debugf("api key pattern %q did not match path %q", pattern.Regex.String(), fp) + logger.Debugf("api key pattern %q did not match path %q", pattern.Regex.String(), fp) } return "", false diff --git a/pkg/apikey/apikey_test.go b/pkg/apikey/apikey_test.go index 14b4abc5..2a8d0b01 100644 --- a/pkg/apikey/apikey_test.go +++ b/pkg/apikey/apikey_test.go @@ -1,6 +1,7 @@ package apikey_test import ( + "context" "os" "path/filepath" "regexp" @@ -10,6 +11,7 @@ import ( "github.com/wakatime/wakatime-cli/pkg/apikey" "github.com/wakatime/wakatime-cli/pkg/heartbeat" + "github.com/wakatime/wakatime-cli/pkg/regex" "github.com/gandarez/go-realpath" "github.com/stretchr/testify/assert" @@ -30,13 +32,13 @@ func TestWithReplacing(t *testing.T) { MapPatterns: []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regexp.MustCompile(`.workdir.`), + Regex: regex.NewRegexpWrap(regexp.MustCompile(`.workdir.`)), }, }, } opt := apikey.WithReplacing(config) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { APIKey: "00000000-0000-4000-8000-000000000000", @@ -55,7 +57,7 @@ func TestWithReplacing(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{first, second}) + result, err := h(context.Background(), []heartbeat.Heartbeat{first, second}) require.NoError(t, err) assert.Equal(t, []heartbeat.Result{ @@ -75,15 +77,15 @@ func TestApiKey_MatchPattern(t *testing.T) { patterns := []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000000", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder")))), }, { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, `test([a-zA-Z]+)`))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, `test([a-zA-Z]+)`)))), }, } - result, ok := apikey.MatchPattern(rp, patterns) + result, ok := apikey.MatchPattern(context.Background(), rp, patterns) assert.True(t, ok) assert.Equal(t, "00000000-0000-4000-8000-000000000001", result) @@ -99,21 +101,21 @@ func TestApiKey_MatchPattern_NoMatch(t *testing.T) { patterns := []apikey.MapPattern{ { APIKey: "00000000-0000-4000-8000-000000000000", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder")))), }, { APIKey: "00000000-0000-4000-8000-000000000001", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "temp"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "temp")))), }, } - _, ok := apikey.MatchPattern(rp, patterns) + _, ok := apikey.MatchPattern(context.Background(), rp, patterns) assert.False(t, ok) } func TestApiKey_MatchPattern_ZeroPatterns(t *testing.T) { - _, ok := apikey.MatchPattern("", []apikey.MapPattern{}) + _, ok := apikey.MatchPattern(context.Background(), "", []apikey.MapPattern{}) assert.False(t, ok) } diff --git a/pkg/backoff/backoff.go b/pkg/backoff/backoff.go index 56df4197..da80450d 100644 --- a/pkg/backoff/backoff.go +++ b/pkg/backoff/backoff.go @@ -1,6 +1,7 @@ package backoff import ( + "context" "errors" "fmt" "math" @@ -39,10 +40,11 @@ type Config struct { // a heartbeat when the api is unresponsive. func WithBackoff(config Config) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute heartbeat backoff algorithm") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute heartbeat backoff algorithm") - if shouldBackoff(config.Retries, config.At) { + if shouldBackoff(ctx, config.Retries, config.At) { if config.HasProxy { return nil, api.ErrBackoff{Err: errors.New("won't send heartbeat due to backoff with proxy")} } @@ -50,11 +52,11 @@ func WithBackoff(config Config) heartbeat.HandleOption { return nil, api.ErrBackoff{Err: errors.New("won't send heartbeat due to backoff without proxy")} } - results, err := next(hh) + results, err := next(ctx, hh) if err != nil { // error response, increment backoff - if updateErr := updateBackoffSettings(config.V, config.Retries+1, time.Now()); updateErr != nil { - log.Warnf("failed to update backoff settings: %s", updateErr) + if updateErr := updateBackoffSettings(ctx, config.V, config.Retries+1, time.Now()); updateErr != nil { + logger.Warnf("failed to update backoff settings: %s", updateErr) } return nil, err @@ -62,8 +64,8 @@ func WithBackoff(config Config) heartbeat.HandleOption { // success response, reset backoff if config.Retries > 0 || !config.At.IsZero() { - if resetErr := updateBackoffSettings(config.V, 0, time.Time{}); resetErr != nil { - log.Warnf("failed to reset backoff settings: %s", resetErr) + if resetErr := updateBackoffSettings(ctx, config.V, 0, time.Time{}); resetErr != nil { + logger.Warnf("failed to reset backoff settings: %s", resetErr) } } @@ -75,17 +77,19 @@ func WithBackoff(config Config) heartbeat.HandleOption { // shouldBackoff returns true if we should save heartbeats directly to offline // database and skip sending to API due to rate limiting from too many recent // networking errors. -func shouldBackoff(retries int, at time.Time) bool { +func shouldBackoff(ctx context.Context, retries int, at time.Time) bool { if retries < 1 || at.IsZero() { return false } + logger := log.Extract(ctx) + backoffSeconds := float64(factor) * math.Pow(2, float64(retries)) duration := time.Duration(backoffSeconds) * time.Second if backoffSeconds > maxBackoffSecs { - log.Debugf( + logger.Debugf( "exponential backoff tried %d times since %s, will reset because reached %s max backoff", retries, at.Format(ini.DateFormat), @@ -99,7 +103,7 @@ func shouldBackoff(retries int, at time.Time) bool { return false } - log.Debugf( + logger.Debugf( "exponential backoff tried %d times since %s, will retry again after %s", retries, at.Format(ini.DateFormat), @@ -109,8 +113,8 @@ func shouldBackoff(retries int, at time.Time) bool { return true } -func updateBackoffSettings(v *viper.Viper, retries int, at time.Time) error { - w, err := ini.NewWriter(v, ini.InternalFilePath) +func updateBackoffSettings(ctx context.Context, v *viper.Viper, retries int, at time.Time) error { + w, err := ini.NewWriter(ctx, v, ini.InternalFilePath) if err != nil { return fmt.Errorf("failed to parse config file: %s", err) } @@ -126,7 +130,7 @@ func updateBackoffSettings(v *viper.Viper, retries int, at time.Time) error { keyValue["backoff_at"] = "" } - if err := w.Write("internal", keyValue); err != nil { + if err := w.Write(ctx, "internal", keyValue); err != nil { return fmt.Errorf("failed to write to internal config file: %s", err) } diff --git a/pkg/backoff/backoff_internal_test.go b/pkg/backoff/backoff_internal_test.go index 5f003ca0..853e4fa7 100644 --- a/pkg/backoff/backoff_internal_test.go +++ b/pkg/backoff/backoff_internal_test.go @@ -1,6 +1,7 @@ package backoff import ( + "context" "os" "testing" "time" @@ -15,7 +16,7 @@ import ( func TestShouldBackoff(t *testing.T) { at := time.Now().Add(time.Second * -1) - should := shouldBackoff(1, at) + should := shouldBackoff(context.Background(), 1, at) assert.True(t, should) } @@ -23,7 +24,7 @@ func TestShouldBackoff(t *testing.T) { func TestShouldBackoff_AfterResetTime(t *testing.T) { at := time.Now().Add(time.Second * -1) - should := shouldBackoff(8, at) + should := shouldBackoff(context.Background(), 8, at) assert.False(t, should) } @@ -31,34 +32,35 @@ func TestShouldBackoff_AfterResetTime(t *testing.T) { func TestShouldBackoff_AfterResetTime_ZeroRetries(t *testing.T) { at := time.Now().Add(maxBackoffSecs + 1*time.Second) - should := shouldBackoff(0, at) + should := shouldBackoff(context.Background(), 0, at) assert.False(t, should) } func TestShouldBackoff_NegateBackoff(t *testing.T) { - should := shouldBackoff(0, time.Time{}) + should := shouldBackoff(context.Background(), 0, time.Time{}) assert.False(t, should) } func TestUpdateBackoffSettings(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + v := viper.New() v.Set("config", tmpFile.Name()) v.Set("internal-config", tmpFile.Name()) at := time.Now().Add(time.Second * -1) - err = updateBackoffSettings(v, 2, at) + err = updateBackoffSettings(ctx, v, 2, at) require.NoError(t, err) - writer, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + writer, err := ini.NewWriter(ctx, v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return tmpFile.Name(), nil }) @@ -71,20 +73,21 @@ func TestUpdateBackoffSettings(t *testing.T) { } func TestUpdateBackoffSettings_NotInBackoff(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + v := viper.New() v.Set("config", tmpFile.Name()) v.Set("internal-config", tmpFile.Name()) - err = updateBackoffSettings(v, 0, time.Time{}) + err = updateBackoffSettings(ctx, v, 0, time.Time{}) require.NoError(t, err) - writer, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + writer, err := ini.NewWriter(ctx, v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return tmpFile.Name(), nil }) diff --git a/pkg/backoff/backoff_test.go b/pkg/backoff/backoff_test.go index 9dad7f4a..29f22913 100644 --- a/pkg/backoff/backoff_test.go +++ b/pkg/backoff/backoff_test.go @@ -1,6 +1,7 @@ package backoff_test import ( + "context" "errors" "os" "testing" @@ -17,20 +18,19 @@ import ( ) func TestWithBackoff(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + v := viper.New() v.Set("internal-config", tmpFile.Name()) opt := backoff.WithBackoff(backoff.Config{ V: v, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -38,7 +38,7 @@ func TestWithBackoff(t *testing.T) { }, nil }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(context.Background(), []heartbeat.Heartbeat{}) require.NoError(t, err) err = ini.ReadInConfig(v, tmpFile.Name()) @@ -50,13 +50,14 @@ func TestWithBackoff(t *testing.T) { } func TestWithBackoff_BeforeNextBackoff(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + v := viper.New() v.Set("internal-config", tmpFile.Name()) at := time.Now() @@ -68,11 +69,11 @@ func TestWithBackoff_BeforeNextBackoff(t *testing.T) { At: at, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{}, errors.New("error") }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.Error(t, err) assert.Equal(t, "error", err.Error()) @@ -91,7 +92,7 @@ func TestWithBackoff_BeforeNextBackoff(t *testing.T) { At: at.Add(time.Second * 15), }) - handle = opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle = opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -99,7 +100,7 @@ func TestWithBackoff_BeforeNextBackoff(t *testing.T) { }, nil }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.Error(t, err) assert.Equal(t, "won't send heartbeat due to backoff without proxy", err.Error()) @@ -121,7 +122,7 @@ func TestWithBackoff_BeforeNextBackoffWithProxy(t *testing.T) { HasProxy: true, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -129,7 +130,7 @@ func TestWithBackoff_BeforeNextBackoffWithProxy(t *testing.T) { }, nil }) - _, err := handle([]heartbeat.Heartbeat{}) + _, err := handle(context.Background(), []heartbeat.Heartbeat{}) require.Error(t, err) assert.Equal(t, "won't send heartbeat due to backoff with proxy", err.Error()) @@ -149,11 +150,11 @@ func TestWithBackoff_ApiError(t *testing.T) { V: v, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{}, errors.New("error") }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(context.Background(), []heartbeat.Heartbeat{}) require.Error(t, err) assert.Equal(t, "error", err.Error()) @@ -167,13 +168,12 @@ func TestWithBackoff_ApiError(t *testing.T) { } func TestWithBackoff_BackoffAndNotReset(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + v := viper.New() v.Set("internal-config", tmpFile.Name()) opt := backoff.WithBackoff(backoff.Config{ @@ -182,7 +182,7 @@ func TestWithBackoff_BackoffAndNotReset(t *testing.T) { At: time.Now().Add(time.Second * -1), }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -190,7 +190,7 @@ func TestWithBackoff_BackoffAndNotReset(t *testing.T) { }, nil }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(context.Background(), []heartbeat.Heartbeat{}) require.Error(t, err) var errbackoff api.ErrBackoff @@ -206,13 +206,14 @@ func TestWithBackoff_BackoffAndNotReset(t *testing.T) { } func TestWithBackoff_BackoffMaxReached(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + v := viper.New() v.Set("internal-config", tmpFile.Name()) // first, cause backoff to be set @@ -220,11 +221,11 @@ func TestWithBackoff_BackoffMaxReached(t *testing.T) { V: v, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{}, errors.New("error") }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.Error(t, err) err = ini.ReadInConfig(v, tmpFile.Name()) @@ -241,7 +242,7 @@ func TestWithBackoff_BackoffMaxReached(t *testing.T) { At: time.Now().Add(time.Second * -1), }) - handle = opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle = opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -249,7 +250,7 @@ func TestWithBackoff_BackoffMaxReached(t *testing.T) { }, nil }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.NoError(t, err) err = ini.ReadInConfig(v, tmpFile.Name()) @@ -261,13 +262,14 @@ func TestWithBackoff_BackoffMaxReached(t *testing.T) { } func TestWithBackoff_BackoffMaxReachedWithZeroRetries(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + v := viper.New() v.Set("internal-config", tmpFile.Name()) // first, cause backoff to be set @@ -275,11 +277,11 @@ func TestWithBackoff_BackoffMaxReachedWithZeroRetries(t *testing.T) { V: v, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{}, errors.New("error") }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.Error(t, err) err = ini.ReadInConfig(v, tmpFile.Name()) @@ -296,7 +298,7 @@ func TestWithBackoff_BackoffMaxReachedWithZeroRetries(t *testing.T) { At: time.Now().Add(time.Hour + 1*time.Second), }) - handle = opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle = opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -304,7 +306,7 @@ func TestWithBackoff_BackoffMaxReachedWithZeroRetries(t *testing.T) { }, nil }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.NoError(t, err) err = ini.ReadInConfig(v, tmpFile.Name()) @@ -316,24 +318,25 @@ func TestWithBackoff_BackoffMaxReachedWithZeroRetries(t *testing.T) { } func TestWithBackoff_ShouldRetry(t *testing.T) { - v := viper.New() - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + v := viper.New() v.Set("internal-config", tmpFile.Name()) opt := backoff.WithBackoff(backoff.Config{ V: v, }) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{}, errors.New("error") }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.Error(t, err) assert.Equal(t, "error", err.Error()) @@ -354,7 +357,7 @@ func TestWithBackoff_ShouldRetry(t *testing.T) { At: at.Add(time.Second * -60), }) - handle = opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle = opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -362,7 +365,7 @@ func TestWithBackoff_ShouldRetry(t *testing.T) { }, nil }) - _, err = handle([]heartbeat.Heartbeat{}) + _, err = handle(ctx, []heartbeat.Heartbeat{}) require.NoError(t, err) err = ini.ReadInConfig(v, tmpFile.Name()) diff --git a/pkg/deps/c.go b/pkg/deps/c.go index 1be426e3..0aee3667 100644 --- a/pkg/deps/c.go +++ b/pkg/deps/c.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -34,7 +35,9 @@ type ParserC struct { } // Parse parses dependencies from C file content using the C lexer. -func (p *ParserC) Parse(filepath string) ([]string, error) { +func (p *ParserC) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -42,7 +45,7 @@ func (p *ParserC) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/c_test.go b/pkg/deps/c_test.go index 5358ef2d..cfe4164e 100644 --- a/pkg/deps/c_test.go +++ b/pkg/deps/c_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserC_Parse(t *testing.T) { parser := deps.ParserC{} - dependencies, err := parser.Parse("testdata/c.c") + dependencies, err := parser.Parse(context.Background(), "testdata/c.c") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/cpp.go b/pkg/deps/cpp.go index fd70057f..a9de8427 100644 --- a/pkg/deps/cpp.go +++ b/pkg/deps/cpp.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -34,7 +35,9 @@ type ParserCPP struct { } // Parse parses dependencies from C++ file content using the C lexer. -func (p *ParserCPP) Parse(filepath string) ([]string, error) { +func (p *ParserCPP) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -42,7 +45,7 @@ func (p *ParserCPP) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/cpp_test.go b/pkg/deps/cpp_test.go index ded41349..4da96e4f 100644 --- a/pkg/deps/cpp_test.go +++ b/pkg/deps/cpp_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -11,7 +12,7 @@ import ( func TestParserCPP_Parse(t *testing.T) { parser := deps.ParserCPP{} - dependencies, err := parser.Parse("testdata/cpp.cpp") + dependencies, err := parser.Parse(context.Background(), "testdata/cpp.cpp") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/csharp.go b/pkg/deps/csharp.go index 76c09fe9..9023de0d 100644 --- a/pkg/deps/csharp.go +++ b/pkg/deps/csharp.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -35,7 +36,9 @@ type ParserCSharp struct { } // Parse parses dependencies from C# file content using the chroma C# lexer. -func (p *ParserCSharp) Parse(filepath string) ([]string, error) { +func (p *ParserCSharp) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -43,7 +46,7 @@ func (p *ParserCSharp) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/csharp_test.go b/pkg/deps/csharp_test.go index aca0d453..bec69251 100644 --- a/pkg/deps/csharp_test.go +++ b/pkg/deps/csharp_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserCSharp_Parse(t *testing.T) { parser := deps.ParserCSharp{} - dependencies, err := parser.Parse("testdata/csharp.cs") + dependencies, err := parser.Parse(context.Background(), "testdata/csharp.cs") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/deps.go b/pkg/deps/deps.go index f307e242..097b2d1c 100644 --- a/pkg/deps/deps.go +++ b/pkg/deps/deps.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "github.com/wakatime/wakatime-cli/pkg/heartbeat" @@ -25,7 +26,7 @@ type Config struct { // DependencyParser is a dependency parser for a programming language. type DependencyParser interface { - Parse(filepath string) ([]string, error) + Parse(ctx context.Context, filepath string) ([]string, error) } // WithDetection initializes and returns a heartbeat handle option, which @@ -34,8 +35,9 @@ type DependencyParser interface { // local file if available. func WithDetection(c Config) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute dependency detection") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute dependency detection") for n, h := range hh { if h.EntityType != heartbeat.FileType { @@ -50,7 +52,7 @@ func WithDetection(c Config) heartbeat.HandleOption { continue } - if heartbeat.ShouldSanitize(h.Entity, c.FilePatterns) { + if heartbeat.ShouldSanitize(ctx, h.Entity, c.FilePatterns) { continue } @@ -62,25 +64,25 @@ func WithDetection(c Config) heartbeat.HandleOption { language, ok := heartbeat.ParseLanguage(*h.Language) if !ok { - log.Debugf("error parsing language of string %q", *h.Language) + logger.Debugf("error parsing language of string %q", *h.Language) } - dependencies, err := Detect(filepath, language) + dependencies, err := Detect(ctx, filepath, language) if err != nil { - log.Debugf("error detecting dependencies: %s", err) + logger.Debugf("error detecting dependencies: %s", err) continue } hh[n].Dependencies = dependencies } - return next(hh) + return next(ctx, hh) } } } // Detect parses the dependencies from a heartbeat file of a specific language. -func Detect(filepath string, language heartbeat.Language) ([]string, error) { +func Detect(ctx context.Context, filepath string, language heartbeat.Language) ([]string, error) { var parser DependencyParser switch language { @@ -126,24 +128,26 @@ func Detect(filepath string, language heartbeat.Language) ([]string, error) { parser = &ParserUnknown{} } - deps, err := parser.Parse(filepath) + deps, err := parser.Parse(ctx, filepath) if err != nil { return nil, fmt.Errorf("failed to parse dependencies: %s", err) } - return filterDependencies(deps), nil + return filterDependencies(ctx, deps), nil } -func filterDependencies(deps []string) []string { +func filterDependencies(ctx context.Context, deps []string) []string { var ( results []string unique = make(map[string]struct{}) ) + logger := log.Extract(ctx) + for _, d := range deps { // filter max size if len(results) >= maxDependenciesCount { - log.Debugf("max size of %d dependencies reached", maxDependenciesCount) + logger.Debugf("max size of %d dependencies reached", maxDependenciesCount) break } @@ -154,7 +158,7 @@ func filterDependencies(deps []string) []string { // filter dependencies off size if d == "" || len(d) > maxDependencyLength { - log.Debugf( + logger.Debugf( "dependency won't be sent because it's either empty or greater than %d characters: %s", maxDependencyLength, d, diff --git a/pkg/deps/deps_test.go b/pkg/deps/deps_test.go index d108a888..0e29b8f5 100644 --- a/pkg/deps/deps_test.go +++ b/pkg/deps/deps_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "regexp" "testing" @@ -15,7 +16,7 @@ import ( func TestWithDetection(t *testing.T) { opt := deps.WithDetection(deps.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Dependencies: []string{ @@ -35,7 +36,7 @@ func TestWithDetection(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{{ + result, err := h(context.Background(), []heartbeat.Heartbeat{{ Entity: "testdata/golang_minimal.go", EntityType: heartbeat.FileType, Language: heartbeat.PointerTo("Go"), @@ -51,10 +52,10 @@ func TestWithDetection(t *testing.T) { func TestWithDetection_SkipSanitized(t *testing.T) { opt := deps.WithDetection(deps.Config{ - FilePatterns: []regex.Regex{regexp.MustCompile(".*")}, + FilePatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh[0].Dependencies, 0) return []heartbeat.Result{ @@ -64,7 +65,7 @@ func TestWithDetection_SkipSanitized(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{{ + result, err := h(context.Background(), []heartbeat.Heartbeat{{ Entity: "testdata/golang.go", EntityType: heartbeat.FileType, Language: heartbeat.PointerTo("Go"), @@ -81,7 +82,7 @@ func TestWithDetection_SkipSanitized(t *testing.T) { func TestWithDetection_LocalFile(t *testing.T) { opt := deps.WithDetection(deps.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Dependencies: []string{ @@ -102,7 +103,7 @@ func TestWithDetection_LocalFile(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{{ + result, err := h(context.Background(), []heartbeat.Heartbeat{{ Entity: "testdata/golang.go", EntityType: heartbeat.FileType, Language: heartbeat.PointerTo("Go"), @@ -120,7 +121,7 @@ func TestWithDetection_LocalFile(t *testing.T) { func TestWithDetection_NonFileType(t *testing.T) { opt := deps.WithDetection(deps.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Entity: "testdata/codefiles/golang.go", @@ -135,7 +136,7 @@ func TestWithDetection_NonFileType(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{{ + result, err := h(context.Background(), []heartbeat.Heartbeat{{ Entity: "testdata/codefiles/golang.go", EntityType: heartbeat.AppType, }}) @@ -149,6 +150,8 @@ func TestWithDetection_NonFileType(t *testing.T) { } func TestDetect(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Filepath string Language heartbeat.Language @@ -266,7 +269,7 @@ func TestDetect(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - deps, err := deps.Detect(test.Filepath, test.Language) + deps, err := deps.Detect(ctx, test.Filepath, test.Language) require.NoError(t, err) assert.Equal(t, test.Dependencies, deps) @@ -276,6 +279,7 @@ func TestDetect(t *testing.T) { func TestDetect_DuplicatesRemoved(t *testing.T) { deps, err := deps.Detect( + context.Background(), "testdata/golang_duplicate.go", heartbeat.LanguageGo, ) @@ -288,6 +292,7 @@ func TestDetect_DuplicatesRemoved(t *testing.T) { func TestDetect_LongDependenciesRemoved(t *testing.T) { deps, err := deps.Detect( + context.Background(), "testdata/python_with_long_import.py", heartbeat.LanguagePython, ) @@ -303,6 +308,7 @@ func TestDetect_LongDependenciesRemoved(t *testing.T) { func TestDetect_MaxDependenciesCountReached(t *testing.T) { deps, err := deps.Detect( + context.Background(), "testdata/python_with_many_imports.py", heartbeat.LanguagePython, ) @@ -313,6 +319,7 @@ func TestDetect_MaxDependenciesCountReached(t *testing.T) { func TestDetect_EmptyDependenciesRemoved(t *testing.T) { deps, err := deps.Detect( + context.Background(), "testdata/bower_empty_dependency.json", heartbeat.LanguageJSON, ) diff --git a/pkg/deps/elm.go b/pkg/deps/elm.go index 5ba1f0b4..dbf4e836 100644 --- a/pkg/deps/elm.go +++ b/pkg/deps/elm.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "strings" @@ -31,7 +32,9 @@ type ParserElm struct { } // Parse parses dependencies from Elm file content using the chroma Elm lexer. -func (p *ParserElm) Parse(filepath string) ([]string, error) { +func (p *ParserElm) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -39,7 +42,7 @@ func (p *ParserElm) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/elm_test.go b/pkg/deps/elm_test.go index e55992be..76e1e6ec 100644 --- a/pkg/deps/elm_test.go +++ b/pkg/deps/elm_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserElm_Parse(t *testing.T) { parser := deps.ParserElm{} - dependencies, err := parser.Parse("testdata/elm.elm") + dependencies, err := parser.Parse(context.Background(), "testdata/elm.elm") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/golang.go b/pkg/deps/golang.go index f6d8d94b..6d82a211 100644 --- a/pkg/deps/golang.go +++ b/pkg/deps/golang.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -34,7 +35,9 @@ type ParserGo struct { } // Parse parses dependencies from Golang file content using the chroma Golang lexer. -func (p *ParserGo) Parse(filepath string) ([]string, error) { +func (p *ParserGo) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -42,7 +45,7 @@ func (p *ParserGo) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/golang_test.go b/pkg/deps/golang_test.go index c59d1f21..0e45a373 100644 --- a/pkg/deps/golang_test.go +++ b/pkg/deps/golang_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserGo_Parse(t *testing.T) { parser := deps.ParserGo{} - dependencies, err := parser.Parse("testdata/golang.go") + dependencies, err := parser.Parse(context.Background(), "testdata/golang.go") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/haskell.go b/pkg/deps/haskell.go index 4457e5e1..aa799a43 100644 --- a/pkg/deps/haskell.go +++ b/pkg/deps/haskell.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "strings" @@ -31,7 +32,9 @@ type ParserHaskell struct { } // Parse parses dependencies from Haskell file content using the chroma Haskell lexer. -func (p *ParserHaskell) Parse(filepath string) ([]string, error) { +func (p *ParserHaskell) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -39,7 +42,7 @@ func (p *ParserHaskell) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/haskell_test.go b/pkg/deps/haskell_test.go index 1a5308fd..bf6abea6 100644 --- a/pkg/deps/haskell_test.go +++ b/pkg/deps/haskell_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserHaskell_Parse(t *testing.T) { parser := deps.ParserHaskell{} - dependencies, err := parser.Parse("testdata/haskell.hs") + dependencies, err := parser.Parse(context.Background(), "testdata/haskell.hs") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/haxe.go b/pkg/deps/haxe.go index d68148c8..51edbc12 100644 --- a/pkg/deps/haxe.go +++ b/pkg/deps/haxe.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -33,7 +34,9 @@ type ParserHaxe struct { } // Parse parses dependencies from Haxe file content using the chroma Haxe lexer. -func (p *ParserHaxe) Parse(filepath string) ([]string, error) { +func (p *ParserHaxe) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -41,7 +44,7 @@ func (p *ParserHaxe) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/haxe_test.go b/pkg/deps/haxe_test.go index c81ba528..8cfa6f57 100644 --- a/pkg/deps/haxe_test.go +++ b/pkg/deps/haxe_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserHaxe_Parse(t *testing.T) { parser := deps.ParserHaxe{} - dependencies, err := parser.Parse("testdata/haxe.hx") + dependencies, err := parser.Parse(context.Background(), "testdata/haxe.hx") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/html.go b/pkg/deps/html.go index 2d489c33..e69da9de 100644 --- a/pkg/deps/html.go +++ b/pkg/deps/html.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -35,7 +36,9 @@ type ParserHTML struct { } // Parse parses dependencies from HTML file content via ReadCloser using the chroma HTML lexer. -func (p *ParserHTML) Parse(filepath string) ([]string, error) { +func (p *ParserHTML) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -43,7 +46,7 @@ func (p *ParserHTML) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/html_test.go b/pkg/deps/html_test.go index 9eddf8c5..d025ca04 100644 --- a/pkg/deps/html_test.go +++ b/pkg/deps/html_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -10,6 +11,8 @@ import ( ) func TestParserHTML_Parse(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Filepath string Expected []string @@ -37,7 +40,7 @@ func TestParserHTML_Parse(t *testing.T) { t.Run(name, func(t *testing.T) { parser := deps.ParserHTML{} - dependencies, err := parser.Parse(test.Filepath) + dependencies, err := parser.Parse(ctx, test.Filepath) require.NoError(t, err) assert.Equal(t, test.Expected, dependencies) diff --git a/pkg/deps/java.go b/pkg/deps/java.go index dc1f13fe..c7f1b87d 100644 --- a/pkg/deps/java.go +++ b/pkg/deps/java.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -37,7 +38,9 @@ type ParserJava struct { } // Parse parses dependencies from Java file content using the chroma Java lexer. -func (p *ParserJava) Parse(filepath string) ([]string, error) { +func (p *ParserJava) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -45,7 +48,7 @@ func (p *ParserJava) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/java_test.go b/pkg/deps/java_test.go index a49cedf2..7d03e9f4 100644 --- a/pkg/deps/java_test.go +++ b/pkg/deps/java_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserJava_Parse(t *testing.T) { parser := deps.ParserJava{} - dependencies, err := parser.Parse("testdata/java.java") + dependencies, err := parser.Parse(context.Background(), "testdata/java.java") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/javascript.go b/pkg/deps/javascript.go index 53db0c45..d9241e4c 100644 --- a/pkg/deps/javascript.go +++ b/pkg/deps/javascript.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -34,7 +35,9 @@ type ParserJavaScript struct { } // Parse parses dependencies from JavaScript file content using the chroma JavaScript lexer. -func (p *ParserJavaScript) Parse(filepath string) ([]string, error) { +func (p *ParserJavaScript) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -42,7 +45,7 @@ func (p *ParserJavaScript) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/javascript_test.go b/pkg/deps/javascript_test.go index 2838ca59..8378fff1 100644 --- a/pkg/deps/javascript_test.go +++ b/pkg/deps/javascript_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -10,6 +11,8 @@ import ( ) func TestParserJavaScript_Parse(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Filepath string Expected []string @@ -69,7 +72,7 @@ func TestParserJavaScript_Parse(t *testing.T) { t.Run(name, func(t *testing.T) { parser := deps.ParserJavaScript{} - dependencies, err := parser.Parse(test.Filepath) + dependencies, err := parser.Parse(ctx, test.Filepath) require.NoError(t, err) assert.Equal(t, test.Expected, dependencies) diff --git a/pkg/deps/json.go b/pkg/deps/json.go index caed8ad8..50ea53b1 100644 --- a/pkg/deps/json.go +++ b/pkg/deps/json.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "path/filepath" @@ -43,7 +44,9 @@ type ParserJSON struct { } // Parse parses dependencies from JSON file content using the chroma JSON lexer. -func (p *ParserJSON) Parse(filepath string) ([]string, error) { +func (p *ParserJSON) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -51,7 +54,7 @@ func (p *ParserJSON) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/json_test.go b/pkg/deps/json_test.go index 0fc95c47..79b9d24b 100644 --- a/pkg/deps/json_test.go +++ b/pkg/deps/json_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -10,6 +11,8 @@ import ( ) func TestParserJSON_Parse(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Filepath string Expected []string @@ -49,7 +52,7 @@ func TestParserJSON_Parse(t *testing.T) { t.Run(name, func(t *testing.T) { parser := deps.ParserJSON{} - dependencies, err := parser.Parse(test.Filepath) + dependencies, err := parser.Parse(ctx, test.Filepath) require.NoError(t, err) assert.Equal(t, test.Expected, dependencies) diff --git a/pkg/deps/kotlin.go b/pkg/deps/kotlin.go index 942a368a..8506ef60 100644 --- a/pkg/deps/kotlin.go +++ b/pkg/deps/kotlin.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -34,7 +35,9 @@ type ParserKotlin struct { } // Parse parses dependencies from Kotlin file content using the chroma Kotlin lexer. -func (p *ParserKotlin) Parse(filepath string) ([]string, error) { +func (p *ParserKotlin) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -42,7 +45,7 @@ func (p *ParserKotlin) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/kotlin_test.go b/pkg/deps/kotlin_test.go index 13d704ef..cf3c2c86 100644 --- a/pkg/deps/kotlin_test.go +++ b/pkg/deps/kotlin_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserKotlin_Parse(t *testing.T) { parser := deps.ParserKotlin{} - dependencies, err := parser.Parse("testdata/kotlin.kt") + dependencies, err := parser.Parse(context.Background(), "testdata/kotlin.kt") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/objectivec.go b/pkg/deps/objectivec.go index 55348550..fc292e33 100644 --- a/pkg/deps/objectivec.go +++ b/pkg/deps/objectivec.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "strings" @@ -31,7 +32,9 @@ type ParserObjectiveC struct { } // Parse parses dependencies from Objective-C file content using the chroma Objective-C lexer. -func (p *ParserObjectiveC) Parse(filepath string) ([]string, error) { +func (p *ParserObjectiveC) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -39,7 +42,7 @@ func (p *ParserObjectiveC) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/objectivec_test.go b/pkg/deps/objectivec_test.go index f9ba0967..b8186fef 100644 --- a/pkg/deps/objectivec_test.go +++ b/pkg/deps/objectivec_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserObjectiveC_Parse(t *testing.T) { parser := deps.ParserObjectiveC{} - dependencies, err := parser.Parse("testdata/objective_c.m") + dependencies, err := parser.Parse(context.Background(), "testdata/objective_c.m") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/php.go b/pkg/deps/php.go index d1ccb168..a8426587 100644 --- a/pkg/deps/php.go +++ b/pkg/deps/php.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -40,7 +41,9 @@ type ParserPHP struct { } // Parse parses dependencies from PHP file content using the chroma PHP lexer. -func (p *ParserPHP) Parse(filepath string) ([]string, error) { +func (p *ParserPHP) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -48,7 +51,7 @@ func (p *ParserPHP) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/php_test.go b/pkg/deps/php_test.go index 9ed5cb2d..670de6d0 100644 --- a/pkg/deps/php_test.go +++ b/pkg/deps/php_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserPHP_Parse(t *testing.T) { parser := deps.ParserPHP{} - dependencies, err := parser.Parse("testdata/php.php") + dependencies, err := parser.Parse(context.Background(), "testdata/php.php") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/python.go b/pkg/deps/python.go index 9793f5c6..33434b4c 100644 --- a/pkg/deps/python.go +++ b/pkg/deps/python.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -37,7 +38,9 @@ type ParserPython struct { } // Parse parses dependencies from Python file content using the chroma Python lexer. -func (p *ParserPython) Parse(filepath string) ([]string, error) { +func (p *ParserPython) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -45,7 +48,7 @@ func (p *ParserPython) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/python_test.go b/pkg/deps/python_test.go index 4b9dbd8b..69931e1d 100644 --- a/pkg/deps/python_test.go +++ b/pkg/deps/python_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserPython_Parse(t *testing.T) { parser := deps.ParserPython{} - dependencies, err := parser.Parse("testdata/python.py") + dependencies, err := parser.Parse(context.Background(), "testdata/python.py") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/rust.go b/pkg/deps/rust.go index a11129fb..0161bd45 100644 --- a/pkg/deps/rust.go +++ b/pkg/deps/rust.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "strings" @@ -33,7 +34,9 @@ type ParserRust struct { } // Parse parses dependencies from Rust file content using the chroma Rust lexer. -func (p *ParserRust) Parse(filepath string) ([]string, error) { +func (p *ParserRust) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -41,7 +44,7 @@ func (p *ParserRust) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/rust_test.go b/pkg/deps/rust_test.go index 24d51a56..ee12e825 100644 --- a/pkg/deps/rust_test.go +++ b/pkg/deps/rust_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserRust_Parse(t *testing.T) { parser := deps.ParserRust{} - dependencies, err := parser.Parse("testdata/rust.rs") + dependencies, err := parser.Parse(context.Background(), "testdata/rust.rs") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/scala.go b/pkg/deps/scala.go index 7bcf4be9..a70288bb 100644 --- a/pkg/deps/scala.go +++ b/pkg/deps/scala.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "strings" @@ -31,7 +32,9 @@ type ParserScala struct { } // Parse parses dependencies from Scala file content using the chroma Scala lexer. -func (p *ParserScala) Parse(filepath string) ([]string, error) { +func (p *ParserScala) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -39,7 +42,7 @@ func (p *ParserScala) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/scala_test.go b/pkg/deps/scala_test.go index e0bd46ab..dc81fc4f 100644 --- a/pkg/deps/scala_test.go +++ b/pkg/deps/scala_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserScala_Parse(t *testing.T) { parser := deps.ParserScala{} - dependencies, err := parser.Parse("testdata/scala.scala") + dependencies, err := parser.Parse(context.Background(), "testdata/scala.scala") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/swift.go b/pkg/deps/swift.go index de160c11..79fcfc5b 100644 --- a/pkg/deps/swift.go +++ b/pkg/deps/swift.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -34,7 +35,9 @@ type ParserSwift struct { } // Parse parses dependencies from Swift file content using the chroma Swift lexer. -func (p *ParserSwift) Parse(filepath string) ([]string, error) { +func (p *ParserSwift) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -42,7 +45,7 @@ func (p *ParserSwift) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/swift_test.go b/pkg/deps/swift_test.go index f0d01da6..692a75fe 100644 --- a/pkg/deps/swift_test.go +++ b/pkg/deps/swift_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserSwift_Parse(t *testing.T) { parser := deps.ParserSwift{} - dependencies, err := parser.Parse("testdata/swift.swift") + dependencies, err := parser.Parse(context.Background(), "testdata/swift.swift") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/deps/unknown.go b/pkg/deps/unknown.go index 3aef7cb9..6bcb0fc5 100644 --- a/pkg/deps/unknown.go +++ b/pkg/deps/unknown.go @@ -1,6 +1,7 @@ package deps import ( + "context" "path/filepath" "strings" ) @@ -21,7 +22,7 @@ type ParserUnknown struct { } // Parse parses dependencies from any file content via ReadCloser using the chroma golang lexer. -func (p *ParserUnknown) Parse(fp string) ([]string, error) { +func (p *ParserUnknown) Parse(_ context.Context, fp string) ([]string, error) { p.init() defer p.init() diff --git a/pkg/deps/unknown_test.go b/pkg/deps/unknown_test.go index fefcba89..ea2b9565 100644 --- a/pkg/deps/unknown_test.go +++ b/pkg/deps/unknown_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -10,6 +11,8 @@ import ( ) func TestParserUnknown_Parse(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Filepath string Expected []string @@ -32,7 +35,7 @@ func TestParserUnknown_Parse(t *testing.T) { t.Run(name, func(t *testing.T) { parser := deps.ParserUnknown{} - dependencies, err := parser.Parse(test.Filepath) + dependencies, err := parser.Parse(ctx, test.Filepath) require.NoError(t, err) assert.Equal(t, test.Expected, dependencies) diff --git a/pkg/deps/vbnet.go b/pkg/deps/vbnet.go index 2454399c..ae7ff420 100644 --- a/pkg/deps/vbnet.go +++ b/pkg/deps/vbnet.go @@ -1,6 +1,7 @@ package deps import ( + "context" "fmt" "io" "regexp" @@ -35,7 +36,9 @@ type ParserVbNet struct { } // Parse parses dependencies from VB.Net file content using the chroma VB.Net lexer. -func (p *ParserVbNet) Parse(filepath string) ([]string, error) { +func (p *ParserVbNet) Parse(ctx context.Context, filepath string) ([]string, error) { + logger := log.Extract(ctx) + reader, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file %q: %s", filepath, err) @@ -43,7 +46,7 @@ func (p *ParserVbNet) Parse(filepath string) ([]string, error) { defer func() { if err := reader.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/deps/vbnet_test.go b/pkg/deps/vbnet_test.go index bbe2960c..a3e14c39 100644 --- a/pkg/deps/vbnet_test.go +++ b/pkg/deps/vbnet_test.go @@ -1,6 +1,7 @@ package deps_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/deps" @@ -12,7 +13,7 @@ import ( func TestParserVbNet_Parse(t *testing.T) { parser := deps.ParserVbNet{} - dependencies, err := parser.Parse("testdata/vbnet.vb") + dependencies, err := parser.Parse(context.Background(), "testdata/vbnet.vb") require.NoError(t, err) assert.Equal(t, []string{ diff --git a/pkg/fileexperts/fileexperts.go b/pkg/fileexperts/fileexperts.go index d6164581..c3653d99 100644 --- a/pkg/fileexperts/fileexperts.go +++ b/pkg/fileexperts/fileexperts.go @@ -1,6 +1,7 @@ package fileexperts import ( + "context" "encoding/json" "fmt" "strings" @@ -47,19 +48,19 @@ type ( // Caller calls wakatime api to get the file expert. type Caller interface { - FileExperts(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) + FileExperts(context.Context, []heartbeat.Heartbeat) ([]heartbeat.Result, error) } // NewHandle creates a new Handle, which acts like a processing pipeline, // with a caller eventually requesting the API. func NewHandle(caller Caller, opts ...heartbeat.HandleOption) heartbeat.Handle { - return func(heartbeats []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { var handle heartbeat.Handle = caller.FileExperts for i := len(opts) - 1; i >= 0; i-- { handle = opts[i](handle) } - return handle(heartbeats) + return handle(ctx, hh) } } diff --git a/pkg/fileexperts/validation.go b/pkg/fileexperts/validation.go index fb2a5163..3e50b18a 100644 --- a/pkg/fileexperts/validation.go +++ b/pkg/fileexperts/validation.go @@ -1,6 +1,8 @@ package fileexperts import ( + "context" + "github.com/wakatime/wakatime-cli/pkg/heartbeat" "github.com/wakatime/wakatime-cli/pkg/log" ) @@ -10,21 +12,22 @@ import ( // before sending it to the API. func WithValidation() heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute fileexperts validation") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute fileexperts validation") var filtered []heartbeat.Heartbeat for _, h := range hh { if !Validate(h) { - log.Debugf("missing required fields for fileexperts") + logger.Debugf("missing required fields for fileexperts") continue } filtered = append(filtered, h) } - return next(filtered) + return next(ctx, filtered) } } } diff --git a/pkg/fileexperts/validation_test.go b/pkg/fileexperts/validation_test.go index 071a777d..764ee799 100644 --- a/pkg/fileexperts/validation_test.go +++ b/pkg/fileexperts/validation_test.go @@ -1,6 +1,7 @@ package fileexperts_test import ( + "context" "testing" "github.com/wakatime/wakatime-cli/pkg/fileexperts" @@ -12,7 +13,7 @@ import ( func TestWithValidation(t *testing.T) { opt := fileexperts.WithValidation() - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Entity: "/path/to/file", @@ -28,7 +29,7 @@ func TestWithValidation(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{ + result, err := h(context.Background(), []heartbeat.Heartbeat{ { Entity: "/path/to/file", Project: heartbeat.PointerTo("wakatime"), diff --git a/pkg/filestats/filestats.go b/pkg/filestats/filestats.go index 916874cc..3f951109 100644 --- a/pkg/filestats/filestats.go +++ b/pkg/filestats/filestats.go @@ -2,6 +2,7 @@ package filestats import ( "bytes" + "context" "fmt" "io" "os" @@ -20,8 +21,9 @@ const maxFileSizeSupported = 2097152 // moment only the total number of lines in a file is detected. func WithDetection() heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute filestats detection") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute filestats detection") for n, h := range hh { if h.EntityType != heartbeat.FileType { @@ -47,12 +49,12 @@ func WithDetection() heartbeat.HandleOption { fileInfo, err := os.Stat(filepath) if err != nil { - log.Warnf("failed to retrieve file stats of file %q: %s", filepath, err) + logger.Warnf("failed to retrieve file stats of file %q: %s", filepath, err) continue } if fileInfo.Size() > maxFileSizeSupported { - log.Debugf( + logger.Debugf( "file %q exceeds max file size of %d bytes. Lines won't be counted", h.Entity, maxFileSizeSupported, @@ -61,21 +63,23 @@ func WithDetection() heartbeat.HandleOption { continue } - lines, err := countLineNumbers(filepath) + lines, err := countLineNumbers(ctx, filepath) if err != nil { - log.Warnf("failed to detect the total number of lines in file %q: %s", filepath, err) + logger.Warnf("failed to detect the total number of lines in file %q: %s", filepath, err) continue } hh[n].Lines = heartbeat.PointerTo(lines) } - return next(hh) + return next(ctx, hh) } } } -func countLineNumbers(filepath string) (int, error) { +func countLineNumbers(ctx context.Context, filepath string) (int, error) { + logger := log.Extract(ctx) + f, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return 0, fmt.Errorf("failed to open file: %s", err) @@ -83,7 +87,7 @@ func countLineNumbers(filepath string) (int, error) { defer func() { if err := f.Close(); err != nil { - log.Debugf("failed to close file: %s", err) + logger.Debugf("failed to close file: %s", err) } }() diff --git a/pkg/filestats/filestats_test.go b/pkg/filestats/filestats_test.go index 54192d7c..f52dfd5f 100644 --- a/pkg/filestats/filestats_test.go +++ b/pkg/filestats/filestats_test.go @@ -2,6 +2,7 @@ package filestats_test import ( "bytes" + "context" "os" "testing" @@ -14,7 +15,7 @@ import ( func TestWithDetection(t *testing.T) { opt := filestats.WithDetection() - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 2) assert.Contains(t, hh, heartbeat.Heartbeat{ EntityType: heartbeat.FileType, @@ -34,7 +35,7 @@ func TestWithDetection(t *testing.T) { }, nil }) - result, err := handle([]heartbeat.Heartbeat{ + result, err := handle(context.Background(), []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: "testdata/first.txt", @@ -55,7 +56,7 @@ func TestWithDetection(t *testing.T) { func TestWithDetection_RemoteFile(t *testing.T) { opt := filestats.WithDetection() - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 1) assert.Contains(t, hh, heartbeat.Heartbeat{ EntityType: heartbeat.FileType, @@ -69,7 +70,7 @@ func TestWithDetection_RemoteFile(t *testing.T) { }, nil }) - result, err := handle([]heartbeat.Heartbeat{ + result, err := handle(context.Background(), []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: "ssh://192.168.1.1/path/to/remote/main.go", @@ -95,7 +96,7 @@ func TestWithDetection_MaxFileSizeExceeded(t *testing.T) { require.NoError(t, err) opt := filestats.WithDetection() - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, hh, []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, @@ -107,7 +108,7 @@ func TestWithDetection_MaxFileSizeExceeded(t *testing.T) { return []heartbeat.Result{}, nil }) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(context.Background(), []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: f.Name(), diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go index cf6eb8a1..7dfc6b6e 100644 --- a/pkg/filter/filter.go +++ b/pkg/filter/filter.go @@ -1,6 +1,7 @@ package filter import ( + "context" "fmt" "os" @@ -22,15 +23,16 @@ type Config struct { // the provided configurations. func WithFiltering(config Config) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute heartbeat filtering") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute heartbeat filtering") var filtered []heartbeat.Heartbeat for _, h := range hh { - err := Filter(h, config) + err := Filter(ctx, h, config) if err != nil { - log.Debugf(err.Error()) + logger.Debugln(err.Error()) continue } @@ -38,7 +40,7 @@ func WithFiltering(config Config) heartbeat.HandleOption { filtered = append(filtered, h) } - return next(filtered) + return next(ctx, filtered) } } } @@ -47,26 +49,28 @@ func WithFiltering(config Config) heartbeat.HandleOption { // can be used to abort execution if all heartbeats were filtered and the list is empty. func WithLengthValidator() heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + if len(hh) == 0 { - log.Debugln("no heartbeats left after filtering. abort heartbeat handling.") + logger.Debugln("no heartbeats left after filtering. abort heartbeat handling.") return []heartbeat.Result{}, nil } - return next(hh) + return next(ctx, hh) } } } // Filter determines, following the passed in configurations, if a heartbeat // should be skipped. -func Filter(h heartbeat.Heartbeat, config Config) error { +func Filter(ctx context.Context, h heartbeat.Heartbeat, config Config) error { // filter by pattern - if err := filterByPattern(h.Entity, config.Include, config.Exclude); err != nil { + if err := filterByPattern(ctx, h.Entity, config.Include, config.Exclude); err != nil { return fmt.Errorf("filter by pattern: %s", err) } - err := filterFileEntity(h, config) + err := filterFileEntity(ctx, h, config) if err != nil { return fmt.Errorf("filter file: %s", err) } @@ -77,21 +81,21 @@ func Filter(h heartbeat.Heartbeat, config Config) error { // filterByPattern determines if a heartbeat should be skipped by checking an // entity against include and exclude patterns. Include will override exclude. // Returns Err to signal to the caller to skip the heartbeat. -func filterByPattern(entity string, include, exclude []regex.Regex) error { +func filterByPattern(ctx context.Context, entity string, include, exclude []regex.Regex) error { if entity == "" { return nil } // filter by include pattern for _, pattern := range include { - if pattern.MatchString(entity) { + if pattern.MatchString(ctx, entity) { return nil } } // filter by exclude pattern for _, pattern := range exclude { - if pattern.MatchString(entity) { + if pattern.MatchString(ctx, entity) { return fmt.Errorf("skipping because matches exclude pattern %q", pattern.String()) } } @@ -103,7 +107,7 @@ func filterByPattern(entity string, include, exclude []regex.Regex) error { // the existence of the passed in filepath, and optionally by checking if a // wakatime project file can be detected in the filepath directory tree. // Returns an error to signal to the caller to skip the heartbeat. -func filterFileEntity(h heartbeat.Heartbeat, config Config) error { +func filterFileEntity(ctx context.Context, h heartbeat.Heartbeat, config Config) error { if h.EntityType != heartbeat.FileType { return nil } @@ -128,7 +132,7 @@ func filterFileEntity(h heartbeat.Heartbeat, config Config) error { // when including only with project file, skip files when the project doesn't have a .wakatime-project file if config.IncludeOnlyWithProjectFile { - _, ok := project.FindFileOrDirectory(entity, project.WakaTimeProjectFile) + _, ok := project.FindFileOrDirectory(ctx, entity, project.WakaTimeProjectFile) if !ok { return fmt.Errorf("skipping because missing .wakatime-project file in parent path") } diff --git a/pkg/filter/filter_test.go b/pkg/filter/filter_test.go index 82997073..0c0d92a3 100644 --- a/pkg/filter/filter_test.go +++ b/pkg/filter/filter_test.go @@ -1,6 +1,7 @@ package filter_test import ( + "context" "errors" "os" "path/filepath" @@ -27,7 +28,7 @@ func TestWithFiltering(t *testing.T) { second.Time++ opt := filter.WithFiltering(filter.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo("heartbeat"), @@ -53,7 +54,7 @@ func TestWithFiltering(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{first, second}) + result, err := h(context.Background(), []heartbeat.Heartbeat{first, second}) require.NoError(t, err) assert.Equal(t, []heartbeat.Result{ @@ -65,11 +66,11 @@ func TestWithFiltering(t *testing.T) { func TestWithLengthValidator(t *testing.T) { opt := filter.WithLengthValidator() - h := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{}, errors.New("this should never be called") }) - result, err := h([]heartbeat.Heartbeat{}) + result, err := h(context.Background(), []heartbeat.Heartbeat{}) require.NoError(t, err) assert.Equal(t, result, []heartbeat.Result{}) @@ -84,7 +85,7 @@ func TestFilter(t *testing.T) { h := testHeartbeat() h.Entity = tmpFile.Name() - err = filter.Filter(h, filter.Config{}) + err = filter.Filter(context.Background(), h, filter.Config{}) require.NoError(t, err) } @@ -93,7 +94,7 @@ func TestFilter_NonFileTypeEmptyEntity(t *testing.T) { h.Entity = "" h.EntityType = heartbeat.AppType - err := filter.Filter(h, filter.Config{}) + err := filter.Filter(context.Background(), h, filter.Config{}) require.NoError(t, err) } @@ -103,7 +104,7 @@ func TestFilter_IsUnsavedEntity(t *testing.T) { h.EntityType = heartbeat.FileType h.IsUnsavedEntity = true - err := filter.Filter(h, filter.Config{}) + err := filter.Filter(context.Background(), h, filter.Config{}) require.NoError(t, err) } @@ -116,7 +117,7 @@ func TestFilter_IncludeMatchOverwritesExcludeMatch(t *testing.T) { h := testHeartbeat() h.Entity = tmpFile.Name() - err = filter.Filter(h, filter.Config{ + err = filter.Filter(context.Background(), h, filter.Config{ Exclude: []regex.Regex{ regex.MustCompile(".*main.go$"), }, @@ -136,7 +137,7 @@ func TestFilter_ErrMatchesExcludePattern(t *testing.T) { h := testHeartbeat() h.Entity = tmpFile.Name() - err = filter.Filter(h, filter.Config{ + err = filter.Filter(context.Background(), h, filter.Config{ Exclude: []regex.Regex{ regex.MustCompile("^.*exclude-this-file.*$"), }, @@ -148,7 +149,7 @@ func TestFilter_ErrMatchesExcludePattern(t *testing.T) { func TestFilter_ErrNonExistingFile(t *testing.T) { h := testHeartbeat() - err := filter.Filter(h, filter.Config{}) + err := filter.Filter(context.Background(), h, filter.Config{}) assert.EqualError(t, err, "filter file: skipping because of non-existing file \"/tmp/main.go\"") } @@ -169,7 +170,7 @@ func TestFilter_ExistingProjectFile(t *testing.T) { h := testHeartbeat() h.Entity = tmpFile.Name() - err = filter.Filter(h, filter.Config{ + err = filter.Filter(context.Background(), h, filter.Config{ IncludeOnlyWithProjectFile: true, }) require.NoError(t, err) @@ -180,7 +181,7 @@ func TestFilter_RemoteFileSkipsFiltering(t *testing.T) { h.LocalFile = h.Entity h.Entity = "ssh://wakatime:1234@192.168.1.1/path/to/remote/main.go" - err := filter.Filter(h, filter.Config{}) + err := filter.Filter(context.Background(), h, filter.Config{}) require.NoError(t, err) } @@ -193,7 +194,7 @@ func TestFilter_ErrNonExistingProjectFile(t *testing.T) { h := testHeartbeat() h.Entity = tmpFile.Name() - err = filter.Filter(h, filter.Config{ + err = filter.Filter(context.Background(), h, filter.Config{ IncludeOnlyWithProjectFile: true, }) diff --git a/pkg/heartbeat/entity_modifier_internal_test.go b/pkg/heartbeat/entity_modifier_internal_test.go index a528d652..ebfde3eb 100644 --- a/pkg/heartbeat/entity_modifier_internal_test.go +++ b/pkg/heartbeat/entity_modifier_internal_test.go @@ -1,6 +1,7 @@ package heartbeat import ( + "context" "os" "path/filepath" "testing" @@ -10,6 +11,8 @@ import ( ) func TestIsXCodePlayground(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Dir string Expected bool @@ -34,7 +37,7 @@ func TestIsXCodePlayground(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - ret := isXCodePlayground(test.Dir) + ret := isXCodePlayground(ctx, test.Dir) assert.Equal(t, test.Expected, ret) }) @@ -42,6 +45,8 @@ func TestIsXCodePlayground(t *testing.T) { } func TestIsXCodeProject(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Dir string Expected bool @@ -58,7 +63,7 @@ func TestIsXCodeProject(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - ret := isXCodeProject(test.Dir) + ret := isXCodeProject(ctx, test.Dir) assert.Equal(t, test.Expected, ret) }) diff --git a/pkg/heartbeat/entity_modify.go b/pkg/heartbeat/entity_modify.go index f25a7d53..a974fe92 100644 --- a/pkg/heartbeat/entity_modify.go +++ b/pkg/heartbeat/entity_modify.go @@ -1,6 +1,7 @@ package heartbeat import ( + "context" "path/filepath" "strings" @@ -11,39 +12,41 @@ import ( // can be used in a heartbeat processing pipeline to change an entity path. func WithEntityModifier() HandleOption { return func(next Handle) Handle { - return func(hh []Heartbeat) ([]Result, error) { - log.Debugln("execute heartbeat entity modifier") + return func(ctx context.Context, hh []Heartbeat) ([]Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute heartbeat entity modifier") for n, h := range hh { // Support XCode playgrounds - if h.EntityType == FileType && isXCodePlayground(h.Entity) { + if h.EntityType == FileType && isXCodePlayground(ctx, h.Entity) { hh[n].Entity = filepath.Join(h.Entity, "Contents.swift") } + // Support XCode projects - if h.EntityType == FileType && isXCodeProject(h.Entity) { + if h.EntityType == FileType && isXCodeProject(ctx, h.Entity) { hh[n].Entity = filepath.Join(h.Entity, "project.pbxproj") } } - return next(hh) + return next(ctx, hh) } } } -func isXCodePlayground(fp string) bool { +func isXCodePlayground(ctx context.Context, fp string) bool { if !(strings.HasSuffix(fp, ".playground") || strings.HasSuffix(fp, ".xcplayground") || strings.HasSuffix(fp, ".xcplaygroundpage")) { return false } - return isDir(fp) + return isDir(ctx, fp) } -func isXCodeProject(fp string) bool { +func isXCodeProject(ctx context.Context, fp string) bool { if !(strings.HasSuffix(fp, ".xcodeproj")) { return false } - return isDir(fp) + return isDir(ctx, fp) } diff --git a/pkg/heartbeat/entity_modify_test.go b/pkg/heartbeat/entity_modify_test.go index 32be3719..72763165 100644 --- a/pkg/heartbeat/entity_modify_test.go +++ b/pkg/heartbeat/entity_modify_test.go @@ -1,6 +1,7 @@ package heartbeat_test import ( + "context" "os" "path/filepath" "testing" @@ -19,7 +20,7 @@ func TestWithEntityModifier_XCodePlayground(t *testing.T) { opt := heartbeat.WithEntityModifier() - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Entity: filepath.Join(tmpDir, "wakatime.playground", "Contents.swift"), @@ -34,7 +35,7 @@ func TestWithEntityModifier_XCodePlayground(t *testing.T) { }, nil }) - result, err := handle([]heartbeat.Heartbeat{ + result, err := handle(context.Background(), []heartbeat.Heartbeat{ { Entity: filepath.Join(tmpDir, "wakatime.playground"), EntityType: heartbeat.FileType, diff --git a/pkg/heartbeat/format.go b/pkg/heartbeat/format.go index 8e9aa8b3..53ccd495 100644 --- a/pkg/heartbeat/format.go +++ b/pkg/heartbeat/format.go @@ -1,6 +1,7 @@ package heartbeat import ( + "context" "path/filepath" "runtime" @@ -14,8 +15,9 @@ import ( // can be used in a heartbeat processing pipeline to format entity's filepath. func WithFormatting() HandleOption { return func(next Handle) Handle { - return func(hh []Heartbeat) ([]Result, error) { - log.Debugln("execute heartbeat filepath formatting") + return func(ctx context.Context, hh []Heartbeat) ([]Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute heartbeat filepath formatting") for n, h := range hh { if h.EntityType != FileType { @@ -26,31 +28,33 @@ func WithFormatting() HandleOption { continue } - hh[n] = Format(h) + hh[n] = Format(ctx, h) } - return next(hh) + return next(ctx, hh) } } } // Format accepts a heartbeat formats it's filepath and returns the formatted version. -func Format(h Heartbeat) Heartbeat { +func Format(ctx context.Context, h Heartbeat) Heartbeat { if !h.IsUnsavedEntity && (runtime.GOOS != "windows" || !windows.IsWindowsNetworkMount(h.Entity)) { - formatLinuxFilePath(&h) + formatLinuxFilePath(ctx, &h) } if runtime.GOOS == "windows" { - formatWindowsFilePath(&h) + formatWindowsFilePath(ctx, &h) } return h } -func formatLinuxFilePath(h *Heartbeat) { +func formatLinuxFilePath(ctx context.Context, h *Heartbeat) { + logger := log.Extract(ctx) + formatted, err := filepath.Abs(h.Entity) if err != nil { - log.Debugf("failed to resolve absolute path for %q: %s", h.Entity, err) + logger.Debugf("failed to resolve absolute path for %q: %s", h.Entity, err) return } @@ -59,7 +63,7 @@ func formatLinuxFilePath(h *Heartbeat) { // evaluate any symlinks formatted, err = realpath.Realpath(h.Entity) if err != nil { - log.Debugf("failed to resolve real path for %q: %s", h.Entity, err) + logger.Debugf("failed to resolve real path for %q: %s", h.Entity, err) return } @@ -68,7 +72,7 @@ func formatLinuxFilePath(h *Heartbeat) { if h.ProjectPathOverride != "" { formatted, err = filepath.Abs(h.ProjectPathOverride) if err != nil { - log.Debugf("failed to resolve absolute path for %q: %s", h.ProjectPathOverride, err) + logger.Debugf("failed to resolve absolute path for %q: %s", h.ProjectPathOverride, err) return } @@ -77,7 +81,7 @@ func formatLinuxFilePath(h *Heartbeat) { // evaluate any symlinks formatted, err = realpath.Realpath(h.ProjectPathOverride) if err != nil { - log.Debugf("failed to resolve real path for %q: %s", h.ProjectPathOverride, err) + logger.Debugf("failed to resolve real path for %q: %s", h.ProjectPathOverride, err) return } @@ -85,7 +89,9 @@ func formatLinuxFilePath(h *Heartbeat) { } } -func formatWindowsFilePath(h *Heartbeat) { +func formatWindowsFilePath(ctx context.Context, h *Heartbeat) { + logger := log.Extract(ctx) + h.Entity = windows.FormatFilePath(h.Entity) if !h.IsUnsavedEntity && !windows.IsWindowsNetworkMount(h.Entity) { @@ -93,7 +99,7 @@ func formatWindowsFilePath(h *Heartbeat) { h.LocalFile, err = windows.FormatLocalFilePath(h.LocalFile, h.Entity) if err != nil { - log.Debugf("failed to format local file path: %s", err) + logger.Debugf("failed to format local file path: %s", err) } } diff --git a/pkg/heartbeat/format_test.go b/pkg/heartbeat/format_test.go index f1a3f33f..2d87f31f 100644 --- a/pkg/heartbeat/format_test.go +++ b/pkg/heartbeat/format_test.go @@ -1,6 +1,7 @@ package heartbeat_test import ( + "context" "path/filepath" "runtime" "testing" @@ -16,7 +17,7 @@ import ( func TestWithFormatting(t *testing.T) { opt := heartbeat.WithFormatting() - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { entity, err := filepath.Abs(hh[0].Entity) require.NoError(t, err) @@ -40,7 +41,7 @@ func TestWithFormatting(t *testing.T) { }, nil }) - result, err := handle([]heartbeat.Heartbeat{{ + result, err := handle(context.Background(), []heartbeat.Heartbeat{{ Entity: "testdata/main.go", EntityType: heartbeat.FileType, }}) @@ -63,7 +64,7 @@ func TestFormat_NetworkMount(t *testing.T) { EntityType: heartbeat.FileType, } - r := heartbeat.Format(h) + r := heartbeat.Format(context.Background(), h) assert.Equal(t, heartbeat.Heartbeat{ Entity: `\\192.168.1.1/apilibrary.sl`, diff --git a/pkg/heartbeat/heartbeat.go b/pkg/heartbeat/heartbeat.go index 32918b5c..14c16b1a 100644 --- a/pkg/heartbeat/heartbeat.go +++ b/pkg/heartbeat/heartbeat.go @@ -1,6 +1,7 @@ package heartbeat import ( + "context" "fmt" "os" "regexp" @@ -154,11 +155,11 @@ type Result struct { // Sender sends heartbeats to the wakatime api. type Sender interface { - SendHeartbeats(hh []Heartbeat) ([]Result, error) + SendHeartbeats(context.Context, []Heartbeat) ([]Result, error) } // Handle does processing of heartbeats. -type Handle func(hh []Heartbeat) ([]Result, error) +type Handle func(context.Context, []Heartbeat) ([]Result, error) // HandleOption is a function, which allows chaining multiple Handles. type HandleOption func(next Handle) Handle @@ -166,22 +167,24 @@ type HandleOption func(next Handle) Handle // NewHandle creates a new Handle, which acts like a processing pipeline, // with a sender eventually sending the heartbeats. func NewHandle(sender Sender, opts ...HandleOption) Handle { - return func(heartbeats []Heartbeat) ([]Result, error) { + return func(ctx context.Context, hh []Heartbeat) ([]Result, error) { var handle Handle = sender.SendHeartbeats for i := len(opts) - 1; i >= 0; i-- { handle = opts[i](handle) } - return handle(heartbeats) + return handle(ctx, hh) } } // UserAgent generates a user agent from various system infos, including a // a passed in value for plugin. -func UserAgent(plugin string) string { +func UserAgent(ctx context.Context, plugin string) string { + logger := log.Extract(ctx) + info, err := goInfo.GetInfo() if err != nil { - log.Debugf("goInfo.GetInfo error: %s", err) + logger.Debugf("goInfo.GetInfo error: %s", err) } if plugin == "" { @@ -191,7 +194,7 @@ func UserAgent(plugin string) string { return fmt.Sprintf( "wakatime/%s (%s-%s-%s) %s %s", version.Version, - strings.TrimSpace(system.OSName()), + strings.TrimSpace(system.OSName(ctx)), strings.TrimSpace(info.Core), strings.TrimSpace(info.Platform), strings.TrimSpace(runtime.Version()), @@ -204,10 +207,12 @@ func PointerTo[t bool | int | string](v t) *t { return &v } -func isDir(filepath string) bool { +func isDir(ctx context.Context, filepath string) bool { + logger := log.Extract(ctx) + info, err := os.Stat(filepath) if err != nil { - log.Warnf("failed to stat filepath %q: %s", filepath, err) + logger.Warnf("failed to stat filepath %q: %s", filepath, err) return false } diff --git a/pkg/heartbeat/heartbeat_test.go b/pkg/heartbeat/heartbeat_test.go index 135972a4..ba7b6376 100644 --- a/pkg/heartbeat/heartbeat_test.go +++ b/pkg/heartbeat/heartbeat_test.go @@ -1,6 +1,7 @@ package heartbeat_test import ( + "context" "encoding/json" "fmt" "io" @@ -146,7 +147,7 @@ func TestHeartbeat_JSON_NilFields(t *testing.T) { func TestNewHandle(t *testing.T) { sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo("test"), @@ -168,18 +169,18 @@ func TestNewHandle(t *testing.T) { opts := []heartbeat.HandleOption{ func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { for i := range hh { hh[i].Branch = heartbeat.PointerTo("test") } - return next(hh) + return next(ctx, hh) } }, } handle := heartbeat.NewHandle(&sender, opts...) - _, err := handle([]heartbeat.Heartbeat{ + _, err := handle(context.Background(), []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, Entity: "/tmp/main.go", @@ -204,7 +205,7 @@ func TestUserAgentUnknownPlugin(t *testing.T) { runtime.Version(), ) - assert.Equal(t, expected, heartbeat.UserAgent("")) + assert.Equal(t, expected, heartbeat.UserAgent(context.Background(), "")) } func TestUserAgent(t *testing.T) { @@ -220,7 +221,7 @@ func TestUserAgent(t *testing.T) { runtime.Version(), ) - assert.Equal(t, expected, heartbeat.UserAgent("testplugin")) + assert.Equal(t, expected, heartbeat.UserAgent(context.Background(), "testplugin")) } func TestRemoteAddressRegex(t *testing.T) { @@ -260,11 +261,11 @@ func TestRemoteAddressRegex(t *testing.T) { } type mockSender struct { - SendHeartbeatsFn func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) + SendHeartbeatsFn func(context.Context, []heartbeat.Heartbeat) ([]heartbeat.Result, error) SendHeartbeatsFnInvoked bool } -func (m *mockSender) SendHeartbeats(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { +func (m *mockSender) SendHeartbeats(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { m.SendHeartbeatsFnInvoked = true - return m.SendHeartbeatsFn(hh) + return m.SendHeartbeatsFn(ctx, hh) } diff --git a/pkg/heartbeat/sanitize.go b/pkg/heartbeat/sanitize.go index 3bfdb966..5ff7c342 100644 --- a/pkg/heartbeat/sanitize.go +++ b/pkg/heartbeat/sanitize.go @@ -1,6 +1,7 @@ package heartbeat import ( + "context" "path/filepath" "strings" @@ -26,27 +27,28 @@ type SanitizeConfig struct { // can be used in a heartbeat processing pipeline to hide sensitive data. func WithSanitization(config SanitizeConfig) HandleOption { return func(next Handle) Handle { - return func(hh []Heartbeat) ([]Result, error) { - log.Debugln("execute heartbeat sanitization") + return func(ctx context.Context, hh []Heartbeat) ([]Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute heartbeat sanitization") for n, h := range hh { - hh[n] = Sanitize(h, config) + hh[n] = Sanitize(ctx, h, config) } - return next(hh) + return next(ctx, hh) } } } // Sanitize accepts a heartbeat sanitizes it's sensitive data following passed // in configuration and returns the sanitized version. On empty config will do nothing. -func Sanitize(h Heartbeat, config SanitizeConfig) Heartbeat { +func Sanitize(ctx context.Context, h Heartbeat, config SanitizeConfig) Heartbeat { if len(h.Dependencies) == 0 { h.Dependencies = nil } switch { - case ShouldSanitize(h.Entity, config.FilePatterns): + case ShouldSanitize(ctx, h.Entity, config.FilePatterns): if h.EntityType == FileType { h.Entity = "HIDDEN" + filepath.Ext(h.Entity) } else { @@ -55,15 +57,15 @@ func Sanitize(h Heartbeat, config SanitizeConfig) Heartbeat { h = sanitizeMetaData(h) - if h.Branch != nil && (len(config.BranchPatterns) == 0 || ShouldSanitize(*h.Branch, config.BranchPatterns)) { + if h.Branch != nil && (len(config.BranchPatterns) == 0 || ShouldSanitize(ctx, *h.Branch, config.BranchPatterns)) { h.Branch = nil } - case h.Project != nil && ShouldSanitize(*h.Project, config.ProjectPatterns): + case h.Project != nil && ShouldSanitize(ctx, *h.Project, config.ProjectPatterns): h = sanitizeMetaData(h) - if h.Branch != nil && (len(config.BranchPatterns) == 0 || ShouldSanitize(*h.Branch, config.BranchPatterns)) { + if h.Branch != nil && (len(config.BranchPatterns) == 0 || ShouldSanitize(ctx, *h.Branch, config.BranchPatterns)) { h.Branch = nil } - case h.Branch != nil && ShouldSanitize(*h.Branch, config.BranchPatterns): + case h.Branch != nil && ShouldSanitize(ctx, *h.Branch, config.BranchPatterns): h.Branch = nil } @@ -142,9 +144,9 @@ func sanitizeMetaData(h Heartbeat) Heartbeat { // ShouldSanitize checks a subject (entity, project, branch) of a heartbeat and // checks it against the passed in regex patterns to determine, if this heartbeat // should be sanitized. -func ShouldSanitize(subject string, patterns []regex.Regex) bool { +func ShouldSanitize(ctx context.Context, subject string, patterns []regex.Regex) bool { for _, p := range patterns { - if p.MatchString(subject) { + if p.MatchString(ctx, subject) { return true } } diff --git a/pkg/heartbeat/sanitize_test.go b/pkg/heartbeat/sanitize_test.go index 6a0eae6d..28e17057 100644 --- a/pkg/heartbeat/sanitize_test.go +++ b/pkg/heartbeat/sanitize_test.go @@ -1,6 +1,7 @@ package heartbeat_test import ( + "context" "regexp" "testing" @@ -13,10 +14,10 @@ import ( func TestWithSanitization_ObfuscateFile(t *testing.T) { opt := heartbeat.WithSanitization(heartbeat.SanitizeConfig{ - FilePatterns: []regex.Regex{regexp.MustCompile(".*")}, + FilePatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, @@ -37,7 +38,7 @@ func TestWithSanitization_ObfuscateFile(t *testing.T) { }, nil }) - result, err := handle([]heartbeat.Heartbeat{testHeartbeat()}) + result, err := handle(context.Background(), []heartbeat.Heartbeat{testHeartbeat()}) require.NoError(t, err) assert.Equal(t, []heartbeat.Result{ @@ -48,6 +49,8 @@ func TestWithSanitization_ObfuscateFile(t *testing.T) { } func TestSanitize_Obfuscate(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Heartbeat heartbeat.Heartbeat Expected heartbeat.Heartbeat @@ -115,8 +118,8 @@ func TestSanitize_Obfuscate(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - r := heartbeat.Sanitize(test.Heartbeat, heartbeat.SanitizeConfig{ - FilePatterns: []regex.Regex{regexp.MustCompile(".*")}, + r := heartbeat.Sanitize(ctx, test.Heartbeat, heartbeat.SanitizeConfig{ + FilePatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) assert.Equal(t, test.Expected, r) @@ -125,9 +128,9 @@ func TestSanitize_Obfuscate(t *testing.T) { } func TestSanitize_ObfuscateFile_SkipBranchIfNotMatching(t *testing.T) { - r := heartbeat.Sanitize(testHeartbeat(), heartbeat.SanitizeConfig{ - FilePatterns: []regex.Regex{regexp.MustCompile(".*")}, - BranchPatterns: []regex.Regex{regexp.MustCompile("not_matching")}, + r := heartbeat.Sanitize(context.Background(), testHeartbeat(), heartbeat.SanitizeConfig{ + FilePatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, + BranchPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("not_matching"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -147,9 +150,9 @@ func TestSanitize_ObfuscateFile_NilFields(t *testing.T) { h := testHeartbeat() h.Branch = nil - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{ - FilePatterns: []regex.Regex{regexp.MustCompile(".*")}, - BranchPatterns: []regex.Regex{regexp.MustCompile(".*")}, + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{ + FilePatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, + BranchPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -165,8 +168,8 @@ func TestSanitize_ObfuscateFile_NilFields(t *testing.T) { } func TestSanitize_ObfuscateProject(t *testing.T) { - r := heartbeat.Sanitize(testHeartbeat(), heartbeat.SanitizeConfig{ - ProjectPatterns: []regex.Regex{regexp.MustCompile(".*")}, + r := heartbeat.Sanitize(context.Background(), testHeartbeat(), heartbeat.SanitizeConfig{ + ProjectPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -182,9 +185,9 @@ func TestSanitize_ObfuscateProject(t *testing.T) { } func TestSanitize_ObfuscateProject_SkipBranchIfNotMatching(t *testing.T) { - r := heartbeat.Sanitize(testHeartbeat(), heartbeat.SanitizeConfig{ - ProjectPatterns: []regex.Regex{regexp.MustCompile(".*")}, - BranchPatterns: []regex.Regex{regexp.MustCompile("not_matching")}, + r := heartbeat.Sanitize(context.Background(), testHeartbeat(), heartbeat.SanitizeConfig{ + ProjectPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, + BranchPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("not_matching"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -204,9 +207,9 @@ func TestSanitize_ObfuscateProject_NilFields(t *testing.T) { h := testHeartbeat() h.Branch = nil - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{ - ProjectPatterns: []regex.Regex{regexp.MustCompile(".*")}, - BranchPatterns: []regex.Regex{regexp.MustCompile(".*")}, + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{ + ProjectPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, + BranchPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -222,8 +225,8 @@ func TestSanitize_ObfuscateProject_NilFields(t *testing.T) { } func TestSanitize_ObfuscateBranch(t *testing.T) { - r := heartbeat.Sanitize(testHeartbeat(), heartbeat.SanitizeConfig{ - BranchPatterns: []regex.Regex{regexp.MustCompile(".*")}, + r := heartbeat.Sanitize(context.Background(), testHeartbeat(), heartbeat.SanitizeConfig{ + BranchPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -247,8 +250,8 @@ func TestSanitize_ObfuscateBranch_NilFields(t *testing.T) { h.Branch = nil h.Project = nil - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{ - BranchPatterns: []regex.Regex{regexp.MustCompile(".*")}, + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{ + BranchPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, }) assert.Equal(t, heartbeat.Heartbeat{ @@ -267,7 +270,7 @@ func TestSanitize_ObfuscateBranch_NilFields(t *testing.T) { } func TestSanitize_EmptyConfigDoNothing(t *testing.T) { - r := heartbeat.Sanitize(testHeartbeat(), heartbeat.SanitizeConfig{}) + r := heartbeat.Sanitize(context.Background(), testHeartbeat(), heartbeat.SanitizeConfig{}) assert.Equal(t, heartbeat.Heartbeat{ Branch: heartbeat.PointerTo("heartbeat"), @@ -290,7 +293,7 @@ func TestSanitize_EmptyConfigDoNothing_EmptyDependencies(t *testing.T) { h := testHeartbeat() h.Dependencies = []string{} - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{}) + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{}) assert.Equal(t, heartbeat.Heartbeat{ Branch: heartbeat.PointerTo("heartbeat"), @@ -313,7 +316,7 @@ func TestSanitize_ObfuscateProjectFolder(t *testing.T) { h.Entity = "/path/to/project/main.go" h.ProjectPath = "/path/to" - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{ + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{ HideProjectFolder: true, }) @@ -341,7 +344,7 @@ func TestSanitize_ObfuscateProjectFolder_Override(t *testing.T) { h.ProjectPath = "/original/folder" h.ProjectPathOverride = "/path/to" - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{ + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{ HideProjectFolder: true, }) @@ -368,7 +371,7 @@ func TestSanitize_ObfuscateCredentials_RemoteFile(t *testing.T) { h := testHeartbeat() h.Entity = "ssh://wakatime:1234@192.168.1.1/path/to/remote/main.go" - r := heartbeat.Sanitize(h, heartbeat.SanitizeConfig{}) + r := heartbeat.Sanitize(context.Background(), h, heartbeat.SanitizeConfig{}) assert.Equal(t, heartbeat.Heartbeat{ Branch: heartbeat.PointerTo("heartbeat"), @@ -388,6 +391,8 @@ func TestSanitize_ObfuscateCredentials_RemoteFile(t *testing.T) { } func TestShouldSanitize(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Subject string Regex []regex.Regex @@ -396,23 +401,23 @@ func TestShouldSanitize(t *testing.T) { "match_single": { Subject: "fix.123", Regex: []regex.Regex{ - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, Expected: true, }, "match_multiple": { Subject: "fix.456", Regex: []regex.Regex{ - regexp.MustCompile("bar.*"), - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("bar.*")), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, Expected: true, }, "not_match": { Subject: "foo", Regex: []regex.Regex{ - regexp.MustCompile("bar.*"), - regexp.MustCompile("fix.*"), + regex.NewRegexpWrap(regexp.MustCompile("bar.*")), + regex.NewRegexpWrap(regexp.MustCompile("fix.*")), }, Expected: false, }, @@ -420,7 +425,7 @@ func TestShouldSanitize(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - shouldSanitize := heartbeat.ShouldSanitize(test.Subject, test.Regex) + shouldSanitize := heartbeat.ShouldSanitize(ctx, test.Subject, test.Regex) assert.Equal(t, test.Expected, shouldSanitize) }) diff --git a/pkg/ini/ini.go b/pkg/ini/ini.go index 82d5ec40..2ceb8d54 100644 --- a/pkg/ini/ini.go +++ b/pkg/ini/ini.go @@ -1,6 +1,7 @@ package ini import ( + "context" "errors" "fmt" "os" @@ -43,7 +44,7 @@ const ( // Writer defines the methods to write to config file. type Writer interface { - Write(section string, keyValue map[string]string) error + Write(ctx context.Context, section string, keyValue map[string]string) error } // WriterConfig stores the configuration necessary to write to config file. @@ -53,15 +54,21 @@ type WriterConfig struct { } // NewWriter creates a new writer instance. -func NewWriter(v *viper.Viper, filepathFn func(v *viper.Viper) (string, error)) (*WriterConfig, error) { - configFilepath, err := filepathFn(v) +func NewWriter( + ctx context.Context, + v *viper.Viper, + filepathFn func(ctx context.Context, v *viper.Viper) (string, error), +) (*WriterConfig, error) { + configFilepath, err := filepathFn(ctx, v) if err != nil { return nil, fmt.Errorf("error getting filepath: %s", err) } + logger := log.Extract(ctx) + // check if file exists if !fileExists(configFilepath) { - log.Debugf("it will create missing config file %q", configFilepath) + logger.Debugf("it will create missing config file %q", configFilepath) f, err := os.Create(configFilepath) // nolint:gosec if err != nil { @@ -88,7 +95,9 @@ func NewWriter(v *viper.Viper, filepathFn func(v *viper.Viper) (string, error)) } // Write persists key(s) and value(s) on disk. -func (w *WriterConfig) Write(section string, keyValue map[string]string) error { +func (w *WriterConfig) Write(ctx context.Context, section string, keyValue map[string]string) error { + logger := log.Extract(ctx) + if w.File == nil || w.ConfigFilepath == "" { return errors.New("got undefined wakatime config file instance") } @@ -108,7 +117,7 @@ func (w *WriterConfig) Write(section string, keyValue map[string]string) error { Clock: &mutexClock{delay: time.Millisecond}, }) if err != nil { - log.Debugf("failed to acquire mutex: %s", err) + logger.Debugf("failed to acquire mutex: %s", err) } defer func() { @@ -137,7 +146,7 @@ func ReadInConfig(v *viper.Viper, configFilePath string) error { } // FilePath returns the path for wakatime config file. -func FilePath(v *viper.Viper) (string, error) { +func FilePath(ctx context.Context, v *viper.Viper) (string, error) { configFilepath := vipertools.GetString(v, "config") if configFilepath != "" { p, err := homedir.Expand(configFilepath) @@ -148,7 +157,7 @@ func FilePath(v *viper.Viper) (string, error) { return p, nil } - home, _, err := WakaHomeDir() + home, _, err := WakaHomeDir(ctx) if err != nil { return "", fmt.Errorf("failed getting user's home directory: %s", err) } @@ -157,7 +166,7 @@ func FilePath(v *viper.Viper) (string, error) { } // ImportFilePath returns the path for import wakatime config file. -func ImportFilePath(v *viper.Viper) (string, error) { +func ImportFilePath(_ context.Context, v *viper.Viper) (string, error) { configFilepath := vipertools.GetString(v, "settings.import_cfg") if configFilepath != "" { p, err := homedir.Expand(configFilepath) @@ -172,7 +181,7 @@ func ImportFilePath(v *viper.Viper) (string, error) { } // InternalFilePath returns the path for the wakatime internal config file. -func InternalFilePath(v *viper.Viper) (string, error) { +func InternalFilePath(ctx context.Context, v *viper.Viper) (string, error) { configFilepath := vipertools.GetString(v, "internal-config") if configFilepath != "" { p, err := homedir.Expand(configFilepath) @@ -183,7 +192,7 @@ func InternalFilePath(v *viper.Viper) (string, error) { return p, nil } - folder, err := WakaResourcesDir() + folder, err := WakaResourcesDir(ctx) if err != nil { return "", fmt.Errorf("failed getting user's home directory: %s", err) } @@ -192,7 +201,9 @@ func InternalFilePath(v *viper.Viper) (string, error) { } // WakaHomeDir returns the current user's home directory. -func WakaHomeDir() (string, WakaHomeType, error) { +func WakaHomeDir(ctx context.Context) (string, WakaHomeType, error) { + logger := log.Extract(ctx) + home, exists := os.LookupEnv("WAKATIME_HOME") if exists && home != "" { home, err := homedir.Expand(home) @@ -200,12 +211,12 @@ func WakaHomeDir() (string, WakaHomeType, error) { return home, WakaHomeTypeEnvVar, nil } - log.Warnf("failed to expand WAKATIME_HOME filepath: %s", err) + logger.Warnf("failed to expand WAKATIME_HOME filepath: %s", err) } home, err := os.UserHomeDir() if err != nil { - log.Warnf("failed to get user home dir: %s", err) + logger.Warnf("failed to get user home dir: %s", err) } if home != "" { @@ -214,7 +225,7 @@ func WakaHomeDir() (string, WakaHomeType, error) { u, err := user.LookupId(strconv.Itoa(os.Getuid())) if err != nil { - log.Warnf("failed to user info by userid: %s", err) + logger.Warnf("failed to user info by userid: %s", err) } if u.HomeDir != "" { @@ -225,8 +236,8 @@ func WakaHomeDir() (string, WakaHomeType, error) { } // WakaResourcesDir returns the ~/.wakatime/ folder. -func WakaResourcesDir() (string, error) { - home, hometype, err := WakaHomeDir() +func WakaResourcesDir(ctx context.Context) (string, error) { + home, hometype, err := WakaHomeDir(ctx) if err != nil { return "", fmt.Errorf("failed getting user's home directory: %s", err) } diff --git a/pkg/ini/ini_test.go b/pkg/ini/ini_test.go index a92f8464..3ce707c3 100644 --- a/pkg/ini/ini_test.go +++ b/pkg/ini/ini_test.go @@ -1,6 +1,7 @@ package ini_test import ( + "context" "errors" "os" "path/filepath" @@ -20,7 +21,7 @@ func TestReadInConfig(t *testing.T) { v := viper.New() v.Set("config", "testdata/wakatime.cfg") - filePath, err := ini.FilePath(v) + filePath, err := ini.FilePath(context.Background(), v) require.NoError(t, err) err = ini.ReadInConfig(v, filePath) @@ -36,11 +37,11 @@ func TestReadInConfig(t *testing.T) { func TestReadInConfig_Multiline(t *testing.T) { multilineOption := viper.IniLoadOptions(iniv1.LoadOptions{AllowPythonMultilineValues: true}) - v := viper.NewWithOptions(multilineOption) + v := viper.NewWithOptions(multilineOption) v.Set("config", "testdata/wakatime-multiline.cfg") - filePath, err := ini.FilePath(v) + filePath, err := ini.FilePath(context.Background(), v) require.NoError(t, err) err = ini.ReadInConfig(v, filePath) @@ -55,14 +56,15 @@ func TestReadInConfig_Multiline(t *testing.T) { func TestReadInConfig_Multiple(t *testing.T) { v := viper.New() - v.Set("config", "testdata/wakatime.cfg") v.Set("internal-config", "testdata/wakatime-internal.cfg") - filePath, err := ini.FilePath(v) + ctx := context.Background() + + filePath, err := ini.FilePath(ctx, v) require.NoError(t, err) - internalFilePath, err := ini.InternalFilePath(v) + internalFilePath, err := ini.InternalFilePath(ctx, v) require.NoError(t, err) err = ini.ReadInConfig(v, filePath) @@ -84,7 +86,7 @@ func TestReadInConfig_Corrupted(t *testing.T) { v.Set("config", "testdata/corrupted.cfg") - filePath, err := ini.FilePath(v) + filePath, err := ini.FilePath(context.Background(), v) require.NoError(t, err) err = ini.ReadInConfig(v, filePath) @@ -108,7 +110,7 @@ func TestReadInConfig_Malformed(t *testing.T) { v := viper.New() v.Set("config", "testdata/malformed.cfg") - filePath, err := ini.FilePath(v) + filePath, err := ini.FilePath(context.Background(), v) require.NoError(t, err) err = ini.ReadInConfig(v, filePath) @@ -119,6 +121,8 @@ func TestFilePath(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) + ctx := context.Background() + tests := map[string]struct { ViperValue string EnvVar string @@ -155,7 +159,7 @@ func TestFilePath(t *testing.T) { defer os.Unsetenv("WAKATIME_HOME") - configFilepath, err := ini.FilePath(v) + configFilepath, err := ini.FilePath(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, configFilepath) @@ -167,6 +171,8 @@ func TestInternalFilePath(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) + ctx := context.Background() + tests := map[string]struct { ViperValue string EnvVar string @@ -195,7 +201,7 @@ func TestInternalFilePath(t *testing.T) { defer os.Unsetenv("WAKATIME_HOME") - configFilepath, err := ini.InternalFilePath(v) + configFilepath, err := ini.InternalFilePath(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, configFilepath) @@ -206,7 +212,7 @@ func TestInternalFilePath(t *testing.T) { func TestNewWriter(t *testing.T) { v := viper.New() - w, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + w, err := ini.NewWriter(context.Background(), v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return "testdata/wakatime.cfg", nil }) @@ -219,7 +225,7 @@ func TestNewWriter(t *testing.T) { func TestNewWriterErr(t *testing.T) { v := viper.New() - _, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + _, err := ini.NewWriter(context.Background(), v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return "", errors.New("error") }) @@ -233,7 +239,7 @@ func TestNewWriter_MissingFile(t *testing.T) { tmpDir := t.TempDir() - w, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + w, err := ini.NewWriter(context.Background(), v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return filepath.Join(tmpDir, "missing.cfg"), nil }) @@ -248,7 +254,7 @@ func TestNewWriter_MissingFile(t *testing.T) { func TestNewWriter_CorruptedFile(t *testing.T) { v := viper.New() - w, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + w, err := ini.NewWriter(context.Background(), v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return "testdata/corrupted.cfg", nil }) @@ -264,6 +270,8 @@ func TestWrite(t *testing.T) { defer tmpFile.Close() + ctx := context.Background() + tests := map[string]struct { Value map[string]string Section string @@ -290,7 +298,7 @@ func TestWrite(t *testing.T) { ConfigFilepath: tmpFile.Name(), } - err := w.Write(test.Section, test.Value) + err := w.Write(ctx, test.Section, test.Value) require.NoError(t, err) }) @@ -298,25 +306,27 @@ func TestWrite(t *testing.T) { } func TestWrite_NoMultilineSideEffects(t *testing.T) { - multilineOption := viper.IniLoadOptions(iniv1.LoadOptions{AllowPythonMultilineValues: true}) - v := viper.NewWithOptions(multilineOption) - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + multilineOption := viper.IniLoadOptions(iniv1.LoadOptions{AllowPythonMultilineValues: true}) + + v := viper.NewWithOptions(multilineOption) v.Set("config", tmpFile.Name()) copyFile(t, "testdata/wakatime-multiline.cfg", tmpFile.Name()) - w, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + w, err := ini.NewWriter(ctx, v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return tmpFile.Name(), nil }) require.NoError(t, err) - err = w.Write("settings", map[string]string{"debug": "true"}) + err = w.Write(ctx, "settings", map[string]string{"debug": "true"}) require.NoError(t, err) actual, err := os.ReadFile(tmpFile.Name()) @@ -331,25 +341,27 @@ func TestWrite_NoMultilineSideEffects(t *testing.T) { } func TestWrite_NullsRemoved(t *testing.T) { - multilineOption := viper.IniLoadOptions(iniv1.LoadOptions{AllowPythonMultilineValues: true}) - v := viper.NewWithOptions(multilineOption) - tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime") require.NoError(t, err) defer tmpFile.Close() + ctx := context.Background() + + multilineOption := viper.IniLoadOptions(iniv1.LoadOptions{AllowPythonMultilineValues: true}) + + v := viper.NewWithOptions(multilineOption) v.Set("config", tmpFile.Name()) copyFile(t, "testdata/wakatime-nulls.cfg", tmpFile.Name()) - w, err := ini.NewWriter(v, func(vp *viper.Viper) (string, error) { + w, err := ini.NewWriter(ctx, v, func(_ context.Context, vp *viper.Viper) (string, error) { assert.Equal(t, v, vp) return tmpFile.Name(), nil }) require.NoError(t, err) - err = w.Write("settings", map[string]string{"debug": "true"}) + err = w.Write(ctx, "settings", map[string]string{"debug": "true"}) require.NoError(t, err) actual, err := os.ReadFile(tmpFile.Name()) @@ -366,7 +378,7 @@ func TestWrite_NullsRemoved(t *testing.T) { func TestWriteErr(t *testing.T) { w := ini.WriterConfig{} - err := w.Write("settings", map[string]string{"debug": "true"}) + err := w.Write(context.Background(), "settings", map[string]string{"debug": "true"}) require.Error(t, err) assert.Equal(t, "got undefined wakatime config file instance", err.Error()) diff --git a/pkg/language/chroma.go b/pkg/language/chroma.go index 223fec95..b9ee1567 100644 --- a/pkg/language/chroma.go +++ b/pkg/language/chroma.go @@ -1,9 +1,10 @@ package language import ( + "context" "fmt" "io" - fp "path/filepath" + "path/filepath" "sort" "strings" @@ -23,9 +24,11 @@ const maxFileSize = 512000 // by customized priority. // If guessLanguage is true, the file content will be used to detect the language. // This is a modified implementation of chroma.lexers.internal.api:Match(). -func detectChromaCustomized(filepath string, guessLanguage bool) (heartbeat.Language, float32, bool) { - _, file := fp.Split(filepath) - filename := fp.Base(file) +func detectChromaCustomized(ctx context.Context, fp string, guessLanguage bool) (heartbeat.Language, float32, bool) { + logger := log.Extract(ctx) + + _, file := filepath.Split(fp) + filename := filepath.Base(file) matched := chroma.PrioritisedLexers{} // First, try primary filename matches. @@ -39,11 +42,11 @@ func detectChromaCustomized(filepath string, guessLanguage bool) (heartbeat.Lang } if len(matched) > 0 { - bestLexer, weight := selectByCustomizedPriority(filepath, matched) + bestLexer, weight := selectByCustomizedPriority(ctx, fp, matched) language, ok := heartbeat.ParseLanguageFromChroma(bestLexer.Config().Name) if !ok { - log.Warnf("failed to parse language from chroma lexer name %q", bestLexer.Config().Name) + logger.Warnf("failed to parse language from chroma lexer name %q", bestLexer.Config().Name) return heartbeat.LanguageUnknown, 0, false } @@ -61,11 +64,11 @@ func detectChromaCustomized(filepath string, guessLanguage bool) (heartbeat.Lang } if len(matched) > 0 { - bestLexer, weight := selectByCustomizedPriority(filepath, matched) + bestLexer, weight := selectByCustomizedPriority(ctx, fp, matched) language, ok := heartbeat.ParseLanguageFromChroma(bestLexer.Config().Name) if !ok { - log.Warnf("failed to parse language from chroma lexer name %q", bestLexer.Config().Name) + logger.Warnf("failed to parse language from chroma lexer name %q", bestLexer.Config().Name) return heartbeat.LanguageUnknown, 0, false } @@ -77,9 +80,9 @@ func detectChromaCustomized(filepath string, guessLanguage bool) (heartbeat.Lang } // Finally, try matching by file content. - head, err := fileHead(filepath) + head, err := fileHead(ctx, fp) if err != nil { - log.Warnf("failed to load head from file %q: %s", filepath, err) + logger.Warnf("failed to load head from file %q: %s", fp, err) return heartbeat.LanguageUnknown, 0, false } @@ -90,7 +93,7 @@ func detectChromaCustomized(filepath string, guessLanguage bool) (heartbeat.Lang if lexer := lexers.Analyse(string(head)); lexer != nil { language, ok := heartbeat.ParseLanguageFromChroma(lexer.Config().Name) if !ok { - log.Warnf("failed to parse language from chroma lexer name %q", lexer.Config().Name) + logger.Warnf("failed to parse language from chroma lexer name %q", lexer.Config().Name) return heartbeat.LanguageUnknown, 0, false } @@ -108,7 +111,9 @@ type weightedLexer struct { } // selectByCustomizedPriority selects the best matching lexer by customized priority evaluation. -func selectByCustomizedPriority(filepath string, lexers chroma.PrioritisedLexers) (chroma.Lexer, float32) { +func selectByCustomizedPriority(ctx context.Context, fp string, lexers chroma.PrioritisedLexers) (chroma.Lexer, float32) { + logger := log.Extract(ctx) + sort.Slice(lexers, func(i, j int) bool { icfg, jcfg := lexers[i].Config(), lexers[j].Config() @@ -121,16 +126,16 @@ func selectByCustomizedPriority(filepath string, lexers chroma.PrioritisedLexers return strings.ToLower(icfg.Name) > strings.ToLower(jcfg.Name) }) - dir, _ := fp.Split(filepath) + dir, _ := filepath.Split(fp) extensions, err := loadFolderExtensions(dir) if err != nil { - log.Warnf("failed to load folder files extensions: %s", err) + logger.Warnf("failed to load folder files extensions: %s", err) } - head, err := fileHead(filepath) + head, err := fileHead(ctx, fp) if err != nil { - log.Warnf("failed to load head from file %q: %s", filepath, err) + logger.Warnf("failed to load head from file %q: %s", fp, err) } var weighted []weightedLexer @@ -200,7 +205,9 @@ func selectByCustomizedPriority(filepath string, lexers chroma.PrioritisedLexers } // fileHead returns the first `maxFileSize` bytes of the file's content. -func fileHead(filepath string) ([]byte, error) { +func fileHead(ctx context.Context, filepath string) ([]byte, error) { + logger := log.Extract(ctx) + f, err := file.OpenNoLock(filepath) // nolint:gosec if err != nil { return nil, fmt.Errorf("failed to open file: %s", err) @@ -208,7 +215,7 @@ func fileHead(filepath string) ([]byte, error) { defer func() { if err := f.Close(); err != nil { - log.Debugf("failed to close file '%s': %s", filepath, err) + logger.Debugf("failed to close file '%s': %s", filepath, err) } }() diff --git a/pkg/language/language.go b/pkg/language/language.go index 5101aea0..77fbfec8 100644 --- a/pkg/language/language.go +++ b/pkg/language/language.go @@ -1,6 +1,7 @@ package language import ( + "context" "fmt" "os" "path/filepath" @@ -21,8 +22,9 @@ type Config struct { // language info to heartbeats of entity type 'file'. func WithDetection(config Config) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute language detection") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute language detection") for n, h := range hh { if hh[n].Language != nil { @@ -35,7 +37,7 @@ func WithDetection(config Config) heartbeat.HandleOption { filepath = h.LocalFile } - language, err := Detect(filepath, config.GuessLanguage) + language, err := Detect(ctx, filepath, config.GuessLanguage) if err != nil && hh[n].LanguageAlternate != "" { hh[n].Language = heartbeat.PointerTo(hh[n].LanguageAlternate) @@ -43,7 +45,7 @@ func WithDetection(config Config) heartbeat.HandleOption { } if err != nil { - log.Debugf("failed to detect language on file entity %q: %s", h.Entity, err) + logger.Debugf("failed to detect language on file entity %q: %s", h.Entity, err) continue } @@ -51,21 +53,21 @@ func WithDetection(config Config) heartbeat.HandleOption { hh[n].Language = heartbeat.PointerTo(language.String()) } - return next(hh) + return next(ctx, hh) } } } // Detect detects the language of a specific file. If guessLanguage is true, // Chroma will be used to detect a language from the file contents. -func Detect(fp string, guessLanguage bool) (heartbeat.Language, error) { - if language, ok := detectSpecialCases(fp); ok { +func Detect(ctx context.Context, fp string, guessLanguage bool) (heartbeat.Language, error) { + if language, ok := detectSpecialCases(ctx, fp); ok { return language, nil } var language heartbeat.Language - languageChroma, weight, ok := detectChromaCustomized(fp, guessLanguage) + languageChroma, weight, ok := detectChromaCustomized(ctx, fp, guessLanguage) if ok { language = languageChroma } @@ -84,7 +86,7 @@ func Detect(fp string, guessLanguage bool) (heartbeat.Language, error) { } // detectSpecialCases detects the language by file extension for some special cases. -func detectSpecialCases(fp string) (heartbeat.Language, bool) { +func detectSpecialCases(ctx context.Context, fp string) (heartbeat.Language, bool) { dir, file := filepath.Split(fp) ext := strings.ToLower(filepath.Ext(file)) @@ -109,11 +111,11 @@ func detectSpecialCases(fp string) (heartbeat.Language, bool) { return heartbeat.LanguageObjectiveCPP, true } - if folderContainsCPPFiles(dir) { + if folderContainsCPPFiles(ctx, dir) { return heartbeat.LanguageCPP, true } - if folderContainsCFiles(dir) { + if folderContainsCFiles(ctx, dir) { return heartbeat.LanguageC, true } } @@ -130,10 +132,12 @@ func detectSpecialCases(fp string) (heartbeat.Language, bool) { } // folderContainsCFiles returns true, if filder contains c files. -func folderContainsCFiles(dir string) bool { +func folderContainsCFiles(ctx context.Context, dir string) bool { + logger := log.Extract(ctx) + extensions, err := loadFolderExtensions(dir) if err != nil { - log.Warnf("failed loading folder extensions: %s", err) + logger.Warnf("failed loading folder extensions: %s", err) return false } @@ -147,10 +151,12 @@ func folderContainsCFiles(dir string) bool { } // folderContainsCFiles returns true, if filder contains c++ files. -func folderContainsCPPFiles(dir string) bool { +func folderContainsCPPFiles(ctx context.Context, dir string) bool { + logger := log.Extract(ctx) + extensions, err := loadFolderExtensions(dir) if err != nil { - log.Warnf("failed loading folder extensions: %s", err) + logger.Warnf("failed loading folder extensions: %s", err) return false } diff --git a/pkg/language/language_test.go b/pkg/language/language_test.go index 7fa02f59..c26dc402 100644 --- a/pkg/language/language_test.go +++ b/pkg/language/language_test.go @@ -1,6 +1,7 @@ package language_test import ( + "context" "fmt" "testing" @@ -15,7 +16,7 @@ import ( func TestWithDetection(t *testing.T) { opt := language.WithDetection(language.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 1) assert.Equal(t, heartbeat.LanguageGo.String(), *hh[0].Language) assert.Equal(t, []heartbeat.Heartbeat{ @@ -33,7 +34,7 @@ func TestWithDetection(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{ + result, err := h(context.Background(), []heartbeat.Heartbeat{ { Entity: "testdata/codefiles/golang.go", EntityType: heartbeat.FileType, @@ -51,7 +52,7 @@ func TestWithDetection(t *testing.T) { func TestWithDetection_Override(t *testing.T) { opt := language.WithDetection(language.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 1) assert.Equal(t, heartbeat.LanguagePython.String(), *hh[0].Language) assert.Equal(t, []heartbeat.Heartbeat{ @@ -69,7 +70,7 @@ func TestWithDetection_Override(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{ + result, err := h(context.Background(), []heartbeat.Heartbeat{ { Entity: "testdata/codefiles/golang.go", EntityType: heartbeat.FileType, @@ -88,7 +89,7 @@ func TestWithDetection_Override(t *testing.T) { func TestWithDetection_NonExistingEntity_Override(t *testing.T) { opt := language.WithDetection(language.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 1) assert.Equal(t, heartbeat.LanguagePython.String(), hh[0].LanguageAlternate) assert.Equal(t, []heartbeat.Heartbeat{ @@ -107,7 +108,7 @@ func TestWithDetection_NonExistingEntity_Override(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{ + result, err := h(context.Background(), []heartbeat.Heartbeat{ { Entity: "nonexisting", EntityType: heartbeat.FileType, @@ -126,7 +127,7 @@ func TestWithDetection_NonExistingEntity_Override(t *testing.T) { func TestWithDetection_Alternate(t *testing.T) { opt := language.WithDetection(language.Config{}) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 1) assert.Equal(t, []heartbeat.Heartbeat{ { @@ -144,7 +145,7 @@ func TestWithDetection_Alternate(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{ + result, err := h(context.Background(), []heartbeat.Heartbeat{ { Entity: "testdata/codefiles/unknown.xyz", EntityType: heartbeat.FileType, @@ -161,96 +162,96 @@ func TestWithDetection_Alternate(t *testing.T) { } func TestDetect_HeaderFile_Corresponding_C_File(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/h_with_c_file/empty.h", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/h_with_c_file/empty.h", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageC, lang) } func TestDetect_HeaderFile_With_C_Files(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/h_with_any_c_file/empty.h", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/h_with_any_c_file/empty.h", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageC, lang) } func TestDetect_HeaderFile_With_C_And_CPP_Files(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/h_with_any_c_and_cpp_files/cpp.h", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/h_with_any_c_and_cpp_files/cpp.h", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageCPP, lang) } func TestDetect_HeaderFile_With_C_And_CXX_Files(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/h_with_any_c_and_cxx_files/cpp.h", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/h_with_any_c_and_cxx_files/cpp.h", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageCPP, lang) } func TestDetect_ObjectiveC_Over_Matlab_MatchingHeader(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/with_mat_file/objective-c.m", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/with_mat_file/objective-c.m", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageObjectiveC, lang) } func TestDetect_ObjectiveC_M_FileInFolder(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/with_mat_file/objective-c.h", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/with_mat_file/objective-c.h", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageObjectiveC, lang) } func TestDetect_ObjectiveCPP_MatchingHeader(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/with_mat_file/objective-cpp.mm", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/with_mat_file/objective-cpp.mm", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageObjectiveCPP, lang) } func TestDetect_ObjectiveCPP_MM_FileInFolder(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/with_mat_file/objective-cpp.h", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/with_mat_file/objective-cpp.h", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageObjectiveCPP, lang) } func TestDetect_ObjectiveC(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/objective-c.m", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/objective-c.m", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageObjectiveC, lang) } func TestDetect_Matlab_Over_ObjectiveC_Mat_FileInFolder(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/with_mat_file/empty.m", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/with_mat_file/empty.m", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageMatlab, lang) } func TestDetect_ObjectiveC_Over_Matlab_NonMatchingHeader(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/matlab_with_headers/empty.m", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/matlab_with_headers/empty.m", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageObjectiveC, lang) } func TestDetect_NonHeaderFile_C_FilesInFolder(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/py_with_c_files/see.py", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/py_with_c_files/see.py", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguagePython, lang) } func TestDetect_Perl_Over_Prolog(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/perl.pl", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/perl.pl", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguagePerl, lang) } func TestDetect_FSharp_Over_Forth(t *testing.T) { - lang, err := language.Detect("testdata/codefiles/fsharp.fs", false) + lang, err := language.Detect(context.Background(), "testdata/codefiles/fsharp.fs", false) require.NoError(t, err) assert.Equal(t, heartbeat.LanguageFSharp, lang) @@ -260,6 +261,8 @@ func TestDetect_ChromaTopLanguagesRetrofit(t *testing.T) { err := lexer.RegisterAll() require.NoError(t, err) + ctx := context.Background() + tests := map[string]struct { Filepaths []string GuessLanguage bool @@ -908,7 +911,7 @@ func TestDetect_ChromaTopLanguagesRetrofit(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { for _, filepath := range test.Filepaths { - lang, err := language.Detect(filepath, test.GuessLanguage) + lang, err := language.Detect(ctx, filepath, test.GuessLanguage) require.NoError(t, err) assert.Equal(t, test.Expected, lang, fmt.Sprintf("Got: %q, want: %q", lang, test.Expected)) diff --git a/pkg/lexer/ruby.go b/pkg/lexer/ruby.go index 2667a051..a8ca6904 100644 --- a/pkg/lexer/ruby.go +++ b/pkg/lexer/ruby.go @@ -2,7 +2,6 @@ package lexer import ( "github.com/wakatime/wakatime-cli/pkg/heartbeat" - "github.com/wakatime/wakatime-cli/pkg/log" "github.com/alecthomas/chroma/v2/lexers" ) @@ -13,13 +12,11 @@ func init() { lexer := lexers.Get(language) if lexer == nil { - log.Debugf("lexer %q not found", language) return } cfg := lexer.Config() if cfg == nil { - log.Debugf("lexer %q config not found", language) return } diff --git a/pkg/log/context.go b/pkg/log/context.go new file mode 100644 index 00000000..55e18c6f --- /dev/null +++ b/pkg/log/context.go @@ -0,0 +1,39 @@ +package log + +import "context" + +type ( + ctxMarker struct{} + + ctxLogger struct { + logger *Logger + } +) + +// nolint:gochecknoglobals +var ctxMarkerKey = &ctxMarker{} + +// Extract takes the call-scoped Logger. +func Extract(ctx context.Context) *Logger { + l, ok := ctx.Value(ctxMarkerKey).(*ctxLogger) + if !ok || l == nil { + return New(false, false, false) + } + + return l.logger +} + +// ToContext adds the log.Logger to the context for extraction later. +// Returning the new context that has been created. +func ToContext(ctx context.Context, logger *Logger) context.Context { + l := &ctxLogger{ + logger: logger, + } + + return context.WithValue(ctx, ctxMarkerKey, l) +} + +// AddField adds a field to the context logger. +func AddField(ctx context.Context, key string, value any) { + Extract(ctx).WithField(key, value) +} diff --git a/pkg/log/log.go b/pkg/log/log.go index b74e216b..e625fa04 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -9,43 +9,40 @@ import ( "github.com/wakatime/wakatime-cli/pkg/version" - l "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" jww "github.com/spf13/jwalterweatherman" ) -// nolint:gochecknoglobals -var ( - logEntry = new() - // Debugf logs a message at level Debug. - Debugf = logEntry.Debugf - // Infof logs a message at level Info. - Infof = logEntry.Infof - // Warnf logs a message at level Warn. - Warnf = logEntry.Warnf - // Errorf logs a message at level Error. - Errorf = logEntry.Errorf - // Fatalf logs a message at level Fatal then the process will exit with status set to 1. - Fatalf = logEntry.Fatalf - // Debugln logs a message at level Debug. - Debugln = logEntry.Debugln - // Infoln logs a message at level Info. - Infoln = logEntry.Infoln - // Warnln logs a message at level Warn. - Warnln = logEntry.Warnln - // Errorln logs a message at level Error. - Errorln = logEntry.Errorln - // Fatalln logs a message at level Fatal then the process will exit with status set to 1. - Fatalln = logEntry.Fatalln -) +// Logger is the log entry. +type Logger struct { + entry *logrus.Entry + metrics bool + sendDiagsOnErrors bool + verbose bool +} + +// New creates a new Logger. +func New(verbose, sendDiagsOnErrors, metrics bool) *Logger { + logger := &Logger{ + entry: new(), + metrics: metrics, + sendDiagsOnErrors: sendDiagsOnErrors, + verbose: verbose, + } + + logger.SetVerbose(verbose) -func new() *l.Entry { - entry := l.NewEntry(&l.Logger{ + return logger +} + +func new() *logrus.Entry { + entry := logrus.NewEntry(&logrus.Logger{ Out: os.Stdout, - Formatter: &l.JSONFormatter{ - FieldMap: l.FieldMap{ - l.FieldKeyTime: "now", - l.FieldKeyFile: "caller", - l.FieldKeyMsg: "message", + Formatter: &logrus.JSONFormatter{ + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "now", + logrus.FieldKeyFile: "caller", + logrus.FieldKeyMsg: "message", }, DisableHTMLEscape: true, CallerPrettyfier: func(f *runtime.Frame) (string, string) { @@ -67,7 +64,7 @@ func new() *l.Entry { fmt.Sprintf("%s:%d", file, f.Line) }, }, - Level: l.InfoLevel, + Level: logrus.InfoLevel, ExitFunc: os.Exit, ReportCaller: true, }) @@ -77,22 +74,50 @@ func new() *l.Entry { return entry } +// IsMetricsEnabled returns true if it should collect metrics. +func (l *Logger) IsMetricsEnabled() bool { + return l.metrics +} + +// IsVerboseEnabled returns true if debug is enabled. +func (l *Logger) IsVerboseEnabled() bool { + return l.verbose +} + // Output returns the current log output. -func Output() io.Writer { - return logEntry.Logger.Out +func (l *Logger) Output() io.Writer { + return l.entry.Logger.Out +} + +// SendDiagsOnErrors returns true if diagnostics should be sent on errors. +func (l *Logger) SendDiagsOnErrors() bool { + return l.sendDiagsOnErrors } // SetOutput defines sets the log output to io.Writer. -func SetOutput(w io.Writer) { - logEntry.Logger.Out = w +func (l *Logger) SetOutput(w io.Writer) { + l.entry.Logger.Out = w } // SetVerbose sets log level to debug if enabled. -func SetVerbose(verbose bool) { +func (l *Logger) SetVerbose(verbose bool) { if verbose { - logEntry.Logger.SetLevel(l.DebugLevel) + l.entry.Logger.SetLevel(logrus.DebugLevel) } else { - logEntry.Logger.SetLevel(l.InfoLevel) + l.entry.Logger.SetLevel(logrus.InfoLevel) + } +} + +// Flush flushes the log output and closes the file. +func (l *Logger) Flush() { + if file, ok := l.entry.Logger.Out.(*os.File); ok { + if err := file.Sync(); err != nil { + l.entry.Debugf("failed to flush log file: %s", err) + } + + if err := file.Close(); err != nil { + l.entry.Debugf("failed to close log file: %s", err) + } } } @@ -107,7 +132,57 @@ func SetJww(verbose bool, w io.Writer) { } } -// WithField adds a single field to the Entry. -func WithField(key string, value any) { - logEntry.Data[key] = value +// Debugf logs a message at level Debug. +func (l *Logger) Debugf(format string, args ...any) { + l.entry.Debugf(format, args...) +} + +// Infof logs a message at level Info. +func (l *Logger) Infof(format string, args ...any) { + l.entry.Infof(format, args...) +} + +// Warnf logs a message at level Warn. +func (l *Logger) Warnf(format string, args ...any) { + l.entry.Warnf(format, args...) +} + +// Errorf logs a message at level Error. +func (l *Logger) Errorf(format string, args ...any) { + l.entry.Errorf(format, args...) +} + +// Fatalf logs a message at level Fatal then the process will exit with status set to 1. +func (l *Logger) Fatalf(format string, args ...any) { + l.entry.Fatalf(format, args...) +} + +// Debugln logs a message at level Debug. +func (l *Logger) Debugln(args ...any) { + l.entry.Debugln(args...) +} + +// Infoln logs a message at level Info. +func (l *Logger) Infoln(args ...any) { + l.entry.Infoln(args...) +} + +// Warnln logs a message at level Warn. +func (l *Logger) Warnln(args ...any) { + l.entry.Warnln(args...) +} + +// Errorln logs a message at level Error. +func (l *Logger) Errorln(args ...any) { + l.entry.Errorln(args...) +} + +// Fatalln logs a message at level Fatal then the process will exit with status set to 1. +func (l *Logger) Fatalln(args ...any) { + l.entry.Fatalln(args...) +} + +// WithField adds a single field to the Logger. +func (l *Logger) WithField(key string, value any) { + l.entry.Data[key] = value } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go new file mode 100644 index 00000000..1390ea26 --- /dev/null +++ b/pkg/log/log_test.go @@ -0,0 +1,45 @@ +package log_test + +import ( + "testing" + + "github.com/wakatime/wakatime-cli/pkg/log" + + "github.com/stretchr/testify/assert" +) + +func TestLog_IsMetricsEnabled(t *testing.T) { + logger := log.New(false, false, true) + + assert.True(t, logger.IsMetricsEnabled()) +} + +func TestLog_IsMetricsEnabled_Disabled(t *testing.T) { + logger := log.New(false, false, false) + + assert.False(t, logger.IsMetricsEnabled()) +} + +func TestLog_IsVerboseEnabled(t *testing.T) { + logger := log.New(true, false, false) + + assert.True(t, logger.IsVerboseEnabled()) +} + +func TestLog_IsVerboseEnabled_Disabled(t *testing.T) { + logger := log.New(false, false, false) + + assert.False(t, logger.IsVerboseEnabled()) +} + +func TestLog_SendDiagsOnErrors(t *testing.T) { + logger := log.New(false, true, false) + + assert.True(t, logger.SendDiagsOnErrors()) +} + +func TestLog_SendDiagsOnErrors_Disabled(t *testing.T) { + logger := log.New(false, false, false) + + assert.False(t, logger.SendDiagsOnErrors()) +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index c5a4b7de..1004f9ff 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -1,6 +1,7 @@ package metrics import ( + "context" "fmt" "os" "path/filepath" @@ -14,8 +15,8 @@ import ( // StartProfiling starts profiling cpu and memory. It returns a function that // should be called to stop profiling and close the files. -func StartProfiling() (func(), error) { - folder, err := ini.WakaResourcesDir() +func StartProfiling(ctx context.Context) (func(), error) { + folder, err := ini.WakaResourcesDir(ctx) if err != nil { return nil, fmt.Errorf("failed getting user's home directory: %s", err) } @@ -32,8 +33,10 @@ func StartProfiling() (func(), error) { return nil, fmt.Errorf("failed to create cpu profile file: %s", err) } + logger := log.Extract(ctx) + if err := pprof.StartCPUProfile(cpuf); err != nil { - log.Errorf("failed to start cpu profile: %s", err) + logger.Errorf("failed to start cpu profile: %s", err) } memf, err := os.Create(filepath.Join(metricsFolder, fmt.Sprintf("mem_%s.profile", now))) // nolint:gosec @@ -42,18 +45,18 @@ func StartProfiling() (func(), error) { } if err := pprof.WriteHeapProfile(memf); err != nil { - log.Errorf("failed to write heap profile: %s", err) + logger.Errorf("failed to write heap profile: %s", err) } return func() { pprof.StopCPUProfile() if err := cpuf.Close(); err != nil { - log.Errorf("failed to close cpu profile file: %s", err) + logger.Errorf("failed to close cpu profile file: %s", err) } if err := memf.Close(); err != nil { - log.Errorf("failed to close mem profile file: %s", err) + logger.Errorf("failed to close mem profile file: %s", err) } }, nil } diff --git a/pkg/offline/legacy.go b/pkg/offline/legacy.go index 8e2d859e..7320b02c 100644 --- a/pkg/offline/legacy.go +++ b/pkg/offline/legacy.go @@ -1,6 +1,7 @@ package offline import ( + "context" "fmt" "path/filepath" @@ -17,7 +18,7 @@ const dbLegacyFilename = ".wakatime.bdb" // the user's $HOME folder cannot be detected, it defaults to the // current directory. // This is used to support the old db file name and will be removed in the future. -func QueueFilepathLegacy(v *viper.Viper) (string, error) { +func QueueFilepathLegacy(ctx context.Context, v *viper.Viper) (string, error) { paramFile := vipertools.GetString(v, "offline-queue-file-legacy") if paramFile != "" { p, err := homedir.Expand(paramFile) @@ -28,7 +29,7 @@ func QueueFilepathLegacy(v *viper.Viper) (string, error) { return p, nil } - home, _, err := ini.WakaHomeDir() + home, _, err := ini.WakaHomeDir(ctx) if err != nil { return dbFilename, fmt.Errorf("failed getting user's home directory, defaulting to current directory: %s", err) } diff --git a/pkg/offline/legacy_test.go b/pkg/offline/legacy_test.go index c2784549..290062ce 100644 --- a/pkg/offline/legacy_test.go +++ b/pkg/offline/legacy_test.go @@ -1,6 +1,7 @@ package offline_test import ( + "context" "os" "path/filepath" "testing" @@ -16,6 +17,8 @@ func TestQueueFilepathLegacy(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) + ctx := context.Background() + tests := map[string]struct { ViperValue string EnvVar string @@ -42,7 +45,7 @@ func TestQueueFilepathLegacy(t *testing.T) { defer os.Unsetenv("WAKATIME_HOME") v := viper.New() - queueFilepath, err := offline.QueueFilepathLegacy(v) + queueFilepath, err := offline.QueueFilepathLegacy(ctx, v) require.NoError(t, err) assert.Equal(t, test.Expected, queueFilepath) diff --git a/pkg/offline/offline.go b/pkg/offline/offline.go index cd6e44f1..ffd93c4a 100644 --- a/pkg/offline/offline.go +++ b/pkg/offline/offline.go @@ -1,6 +1,7 @@ package offline import ( + "context" "encoding/json" "errors" "fmt" @@ -45,7 +46,7 @@ const ( type Noop struct{} // SendHeartbeats always returns an error. -func (Noop) SendHeartbeats(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { +func (Noop) SendHeartbeats(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return nil, api.Err{Err: errors.New("skip sending heartbeats and only save to offline db")} } @@ -57,20 +58,21 @@ func (Noop) SendHeartbeats(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) // at next usages of the wakatime cli. func WithQueue(filepath string) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugf("execute offline queue with file %s", filepath) + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugf("execute offline queue with file %s", filepath) if len(hh) == 0 { - log.Debugln("abort execution, as there are no heartbeats ready for sending") + logger.Debugln("abort execution, as there are no heartbeats ready for sending") return nil, nil } - results, err := next(hh) + results, err := next(ctx, hh) if err != nil { - log.Debugf("pushing %d heartbeat(s) to queue after error: %s", len(hh), err) + logger.Debugf("pushing %d heartbeat(s) to queue after error: %s", len(hh), err) - requeueErr := pushHeartbeatsWithRetry(filepath, hh) + requeueErr := pushHeartbeatsWithRetry(ctx, filepath, hh) if requeueErr != nil { return nil, fmt.Errorf( "failed to push heartbeats to queue: %s", @@ -81,7 +83,7 @@ func WithQueue(filepath string) heartbeat.HandleOption { return nil, err } - err = handleResults(filepath, results, hh) + err = handleResults(ctx, filepath, results, hh) if err != nil { return nil, fmt.Errorf("failed to handle results: %s", err) } @@ -94,7 +96,7 @@ func WithQueue(filepath string) heartbeat.HandleOption { // QueueFilepath returns the path for offline queue db file. If // the resource directory cannot be detected, it defaults to the // current directory. -func QueueFilepath(v *viper.Viper) (string, error) { +func QueueFilepath(ctx context.Context, v *viper.Viper) (string, error) { paramFile := vipertools.GetString(v, "offline-queue-file") if paramFile != "" { p, err := homedir.Expand(paramFile) @@ -105,7 +107,7 @@ func QueueFilepath(v *viper.Viper) (string, error) { return p, nil } - folder, err := ini.WakaResourcesDir() + folder, err := ini.WakaResourcesDir(ctx) if err != nil { return dbFilename, fmt.Errorf("failed getting resource directory, defaulting to current directory: %s", err) } @@ -118,10 +120,11 @@ func QueueFilepath(v *viper.Viper) (string, error) { // from offline queue and send the heartbeats to WakaTime API. func WithSync(filepath string, syncLimit int) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugf("execute offline sync with file %s", filepath) + return func(ctx context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugf("execute offline sync with file %s", filepath) - err := Sync(filepath, syncLimit)(next) + err := Sync(ctx, filepath, syncLimit)(next) if err != nil { return nil, fmt.Errorf("failed to sync offline heartbeats: %s", err) } @@ -132,7 +135,7 @@ func WithSync(filepath string, syncLimit int) heartbeat.HandleOption { } // Sync returns a function to send queued heartbeats to the WakaTime API. -func Sync(filepath string, syncLimit int) func(next heartbeat.Handle) error { +func Sync(ctx context.Context, filepath string, syncLimit int) func(next heartbeat.Handle) error { return func(next heartbeat.Handle) error { var ( alreadySent int @@ -143,6 +146,8 @@ func Sync(filepath string, syncLimit int) func(next heartbeat.Handle) error { syncLimit = math.MaxInt32 } + logger := log.Extract(ctx) + for { run++ @@ -157,30 +162,30 @@ func Sync(filepath string, syncLimit int) func(next heartbeat.Handle) error { alreadySent += num } - hh, err := popHeartbeats(filepath, num) + hh, err := popHeartbeats(ctx, filepath, num) if err != nil { return fmt.Errorf("failed to fetch heartbeat from offline queue: %s", err) } if len(hh) == 0 { - log.Debugln("no queued heartbeats ready for sending") + logger.Debugln("no queued heartbeats ready for sending") break } - log.Debugf("send %d heartbeats on sync run %d", len(hh), run) + logger.Debugf("send %d heartbeats on sync run %d", len(hh), run) - results, err := next(hh) + results, err := next(ctx, hh) if err != nil { - requeueErr := pushHeartbeatsWithRetry(filepath, hh) + requeueErr := pushHeartbeatsWithRetry(ctx, filepath, hh) if requeueErr != nil { - log.Warnf("failed to push heartbeats to queue after api error: %s", requeueErr) + logger.Warnf("failed to push heartbeats to queue after api error: %s", requeueErr) } return err } - err = handleResults(filepath, results, hh) + err = handleResults(ctx, filepath, results, hh) if err != nil { return fmt.Errorf("failed to handle heartbeats api results: %s", err) } @@ -190,30 +195,32 @@ func Sync(filepath string, syncLimit int) func(next heartbeat.Handle) error { } } -func handleResults(filepath string, results []heartbeat.Result, hh []heartbeat.Heartbeat) error { +func handleResults(ctx context.Context, filepath string, results []heartbeat.Result, hh []heartbeat.Heartbeat) error { var ( err error withInvalidStatus []heartbeat.Heartbeat ) + logger := log.Extract(ctx) + // push heartbeats with invalid result status codes to queue for n, result := range results { if n >= len(hh) { - log.Warnln("results from api not matching heartbeats sent") + logger.Warnln("results from api not matching heartbeats sent") break } if result.Status == http.StatusBadRequest { serialized, jsonErr := json.Marshal(result.Heartbeat) if jsonErr != nil { - log.Warnf( + logger.Warnf( "failed to json marshal heartbeat: %s. heartbeat: %#v", jsonErr, result.Heartbeat, ) } - log.Debugf("heartbeat result status bad request: %s", string(serialized)) + logger.Debugf("heartbeat result status bad request: %s", string(serialized)) continue } @@ -224,32 +231,32 @@ func handleResults(filepath string, results []heartbeat.Result, hh []heartbeat.H } if len(withInvalidStatus) > 0 { - log.Debugf("pushing %d heartbeat(s) with invalid result to queue", len(withInvalidStatus)) + logger.Debugf("pushing %d heartbeat(s) with invalid result to queue", len(withInvalidStatus)) - err = pushHeartbeatsWithRetry(filepath, withInvalidStatus) + err = pushHeartbeatsWithRetry(ctx, filepath, withInvalidStatus) if err != nil { - log.Warnf("failed to push heartbeats with invalid status to queue: %s", err) + logger.Warnf("failed to push heartbeats with invalid status to queue: %s", err) } } // handle leftover heartbeats leftovers := len(hh) - len(results) if leftovers > 0 { - log.Warnf("missing %d results from api.", leftovers) + logger.Warnf("missing %d results from api.", leftovers) start := len(hh) - leftovers - err = pushHeartbeatsWithRetry(filepath, hh[start:]) + err = pushHeartbeatsWithRetry(ctx, filepath, hh[start:]) if err != nil { - log.Warnf("failed to push leftover heartbeats to queue: %s", err) + logger.Warnf("failed to push leftover heartbeats to queue: %s", err) } } return err } -func popHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { - db, close, err := openDB(filepath) +func popHeartbeats(ctx context.Context, filepath string, limit int) ([]heartbeat.Heartbeat, error) { + db, close, err := openDB(ctx, filepath) if err != nil { return nil, err } @@ -262,12 +269,13 @@ func popHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { } queue := NewQueue(tx) + logger := log.Extract(ctx) queued, err := queue.PopMany(limit) if err != nil { errrb := tx.Rollback() if errrb != nil { - log.Errorf("failed to rollback transaction: %s", errrb) + logger.Errorf("failed to rollback transaction: %s", errrb) } return nil, fmt.Errorf("failed to pop heartbeat(s) from queue: %s", err) @@ -280,17 +288,19 @@ func popHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { return queued, nil } -func pushHeartbeatsWithRetry(filepath string, hh []heartbeat.Heartbeat) error { +func pushHeartbeatsWithRetry(ctx context.Context, filepath string, hh []heartbeat.Heartbeat) error { var ( count int err error ) + logger := log.Extract(ctx) + for { if count >= maxRequeueAttempts { serialized, jsonErr := json.Marshal(hh) if jsonErr != nil { - log.Warnf("failed to json marshal heartbeats: %s. heartbeats: %#v", jsonErr, hh) + logger.Warnf("failed to json marshal heartbeats: %s. heartbeats: %#v", jsonErr, hh) } return fmt.Errorf( @@ -301,7 +311,7 @@ func pushHeartbeatsWithRetry(filepath string, hh []heartbeat.Heartbeat) error { ) } - err = pushHeartbeats(filepath, hh) + err = pushHeartbeats(ctx, filepath, hh) if err != nil { count++ @@ -318,8 +328,8 @@ func pushHeartbeatsWithRetry(filepath string, hh []heartbeat.Heartbeat) error { return nil } -func pushHeartbeats(filepath string, hh []heartbeat.Heartbeat) error { - db, close, err := openDB(filepath) +func pushHeartbeats(ctx context.Context, filepath string, hh []heartbeat.Heartbeat) error { + db, close, err := openDB(ctx, filepath) if err != nil { return err } @@ -346,8 +356,8 @@ func pushHeartbeats(filepath string, hh []heartbeat.Heartbeat) error { } // CountHeartbeats returns the total number of heartbeats in the offline db. -func CountHeartbeats(filepath string) (int, error) { - db, close, err := openDB(filepath) +func CountHeartbeats(ctx context.Context, filepath string) (int, error) { + db, close, err := openDB(ctx, filepath) if err != nil { return 0, err } @@ -359,10 +369,12 @@ func CountHeartbeats(filepath string) (int, error) { return 0, fmt.Errorf("failed to start db transaction: %s", err) } + logger := log.Extract(ctx) + defer func() { err := tx.Rollback() if err != nil { - log.Errorf("failed to rollback transaction: %s", err) + logger.Errorf("failed to rollback transaction: %s", err) } }() @@ -377,8 +389,8 @@ func CountHeartbeats(filepath string) (int, error) { } // ReadHeartbeats reads the informed heartbeats in the offline db. -func ReadHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { - db, close, err := openDB(filepath) +func ReadHeartbeats(ctx context.Context, filepath string, limit int) ([]heartbeat.Heartbeat, error) { + db, close, err := openDB(ctx, filepath) if err != nil { return nil, err } @@ -391,10 +403,11 @@ func ReadHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { } queue := NewQueue(tx) + logger := log.Extract(ctx) hh, err := queue.ReadMany(limit) if err != nil { - log.Errorf("failed to read offline heartbeats: %s", err) + logger.Errorf("failed to read offline heartbeats: %s", err) _ = tx.Rollback() @@ -403,7 +416,7 @@ func ReadHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { err = tx.Rollback() if err != nil { - log.Warnf("failed to rollback transaction: %s", err) + logger.Warnf("failed to rollback transaction: %s", err) } return hh, nil @@ -412,7 +425,7 @@ func ReadHeartbeats(filepath string, limit int) ([]heartbeat.Heartbeat, error) { // openDB opens a connection to the offline db. // It returns the pointer to bolt.DB, a function to close the connection and an error. // Although named parameters should be avoided, this func uses them to access inside the deferred function and set an error. -func openDB(filepath string) (db *bolt.DB, _ func(), err error) { +func openDB(ctx context.Context, filepath string) (db *bolt.DB, _ func(), err error) { defer func() { if r := recover(); r != nil { err = ErrOpenDB{Err: fmt.Errorf("panicked: %v", r)} @@ -424,16 +437,18 @@ func openDB(filepath string) (db *bolt.DB, _ func(), err error) { return nil, nil, fmt.Errorf("failed to open db file: %s", err) } + logger := log.Extract(ctx) + return db, func() { // recover from panic when closing db defer func() { if r := recover(); r != nil { - log.Warnf("panicked: failed to close db file: %v", r) + logger.Warnf("panicked: failed to close db file: %v", r) } }() if err := db.Close(); err != nil { - log.Debugf("failed to close db file: %s", err) + logger.Debugf("failed to close db file: %s", err) } }, err } diff --git a/pkg/offline/offline_test.go b/pkg/offline/offline_test.go index 6d150d33..b13731a8 100644 --- a/pkg/offline/offline_test.go +++ b/pkg/offline/offline_test.go @@ -1,6 +1,7 @@ package offline_test import ( + "context" "encoding/json" "errors" "fmt" @@ -22,6 +23,8 @@ import ( ) func TestQueueFilepath(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { EnvVar string }{ @@ -41,11 +44,11 @@ func TestQueueFilepath(t *testing.T) { defer os.Unsetenv("WAKATIME_HOME") - folder, err := ini.WakaResourcesDir() + folder, err := ini.WakaResourcesDir(ctx) require.NoError(t, err) v := viper.New() - queueFilepath, err := offline.QueueFilepath(v) + queueFilepath, err := offline.QueueFilepath(ctx, v) require.NoError(t, err) expected := filepath.Join(folder, "offline_heartbeats.bdb") @@ -80,7 +83,7 @@ func TestWithQueue(t *testing.T) { opt := offline.WithQueue(f.Name()) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 2) assert.Contains(t, hh, testHeartbeats()[0]) assert.Contains(t, hh, testHeartbeats()[1]) @@ -98,7 +101,7 @@ func TestWithQueue(t *testing.T) { }) // run - results, err := handle([]heartbeat.Heartbeat{ + results, err := handle(context.Background(), []heartbeat.Heartbeat{ testHeartbeats()[0], testHeartbeats()[1], }) @@ -153,14 +156,14 @@ func TestWithQueue_NoHeartbeats(t *testing.T) { opt := offline.WithQueue(f.Name()) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Len(t, hh, 0) return []heartbeat.Result{}, nil }) // run - results, err := handle([]heartbeat.Heartbeat{}) + results, err := handle(context.Background(), []heartbeat.Heartbeat{}) require.NoError(t, err) // check @@ -176,7 +179,7 @@ func TestWithQueue_ApiError(t *testing.T) { opt := offline.WithQueue(f.Name()) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, hh, []heartbeat.Heartbeat{ testHeartbeats()[0], testHeartbeats()[1], @@ -186,7 +189,7 @@ func TestWithQueue_ApiError(t *testing.T) { }) // run - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(context.Background(), []heartbeat.Heartbeat{ testHeartbeats()[0], testHeartbeats()[1], }) @@ -239,7 +242,7 @@ func TestWithQueue_InvalidResults(t *testing.T) { opt := offline.WithQueue(f.Name()) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, hh, testHeartbeats()) return []heartbeat.Result{ @@ -259,7 +262,7 @@ func TestWithQueue_InvalidResults(t *testing.T) { }) // run - results, err := handle(testHeartbeats()) + results, err := handle(context.Background(), testHeartbeats()) require.NoError(t, err) // check @@ -325,7 +328,7 @@ func TestWithQueue_HandleLeftovers(t *testing.T) { opt := offline.WithQueue(f.Name()) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, hh, testHeartbeats()) return []heartbeat.Result{ @@ -337,7 +340,7 @@ func TestWithQueue_HandleLeftovers(t *testing.T) { }) // run - results, err := handle(testHeartbeats()) + results, err := handle(context.Background(), testHeartbeats()) require.NoError(t, err) // check @@ -417,7 +420,7 @@ func TestWithSync(t *testing.T) { opt := offline.WithSync(f.Name(), offline.SyncMaxDefault) - handle := opt(func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: http.StatusCreated, @@ -431,7 +434,7 @@ func TestWithSync(t *testing.T) { }) // run - results, err := handle(nil) + results, err := handle(context.Background(), nil) require.NoError(t, err) // check @@ -486,12 +489,12 @@ func TestSync_MultipleRequests(t *testing.T) { err = db.Close() require.NoError(t, err) - syncFn := offline.Sync(f.Name(), 1000) + syncFn := offline.Sync(context.Background(), f.Name(), 1000) var numCalls int // run - err = syncFn(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + err = syncFn(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { numCalls++ // first request @@ -585,12 +588,12 @@ func TestSync_APIError(t *testing.T) { err = db.Close() require.NoError(t, err) - syncFn := offline.Sync(f.Name(), 10) + syncFn := offline.Sync(context.Background(), f.Name(), 10) var numCalls int // run - err = syncFn(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + err = syncFn(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { numCalls++ assert.Equal(t, []heartbeat.Heartbeat{ @@ -673,12 +676,12 @@ func TestSync_InvalidResults(t *testing.T) { err = db.Close() require.NoError(t, err) - syncFn := offline.Sync(f.Name(), 1000) + syncFn := offline.Sync(context.Background(), f.Name(), 1000) var numCalls int // run - err = syncFn(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + err = syncFn(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { numCalls++ // first request @@ -782,12 +785,12 @@ func TestSync_SyncLimit(t *testing.T) { err = db.Close() require.NoError(t, err) - syncFn := offline.Sync(f.Name(), 1) + syncFn := offline.Sync(context.Background(), f.Name(), 1) var numCalls int // run - err = syncFn(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + err = syncFn(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { numCalls++ assert.Len(t, hh, 1) @@ -862,12 +865,12 @@ func TestSync_SyncUnlimited(t *testing.T) { err = db.Close() require.NoError(t, err) - syncFn := offline.Sync(f.Name(), 0) + syncFn := offline.Sync(context.Background(), f.Name(), 0) var numCalls int // run - err = syncFn(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + err = syncFn(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { numCalls++ assert.Len(t, hh, 2) @@ -941,7 +944,7 @@ func TestCountHeartbeats(t *testing.T) { err = db.Close() require.NoError(t, err) - count, err := offline.CountHeartbeats(f.Name()) + count, err := offline.CountHeartbeats(context.Background(), f.Name()) require.NoError(t, err) assert.Equal(t, count, 3) @@ -954,7 +957,7 @@ func TestCountHeartbeats_Empty(t *testing.T) { defer f.Close() - count, err := offline.CountHeartbeats(f.Name()) + count, err := offline.CountHeartbeats(context.Background(), f.Name()) require.NoError(t, err) assert.Equal(t, count, 0) @@ -990,7 +993,7 @@ func TestReadHeartbeats(t *testing.T) { err = db.Close() require.NoError(t, err) - hh, err := offline.ReadHeartbeats(f.Name(), offline.PrintMaxDefault) + hh, err := offline.ReadHeartbeats(context.Background(), f.Name(), offline.PrintMaxDefault) require.NoError(t, err) assert.Len(t, hh, 2) @@ -1026,7 +1029,7 @@ func TestReadHeartbeats_WithLimit(t *testing.T) { err = db.Close() require.NoError(t, err) - hh, err := offline.ReadHeartbeats(f.Name(), 1) + hh, err := offline.ReadHeartbeats(context.Background(), f.Name(), 1) require.NoError(t, err) assert.Len(t, hh, 1) @@ -1039,7 +1042,7 @@ func TestReadHeartbeats_Empty(t *testing.T) { defer f.Close() - hh, err := offline.ReadHeartbeats(f.Name(), offline.PrintMaxDefault) + hh, err := offline.ReadHeartbeats(context.Background(), f.Name(), offline.PrintMaxDefault) require.NoError(t, err) assert.Len(t, hh, 0) diff --git a/pkg/project/file.go b/pkg/project/file.go index b1bd5556..58d76338 100644 --- a/pkg/project/file.go +++ b/pkg/project/file.go @@ -2,6 +2,7 @@ package project import ( "bufio" + "context" "errors" "fmt" "os" @@ -19,15 +20,16 @@ type File struct { // Detect get information from a .wakatime-project file about the project for // a given file. First line of .wakatime-project sets the project // name. Second line sets the current branch name. -func (f File) Detect() (Result, bool, error) { - fp, found := FindFileOrDirectory(f.Filepath, WakaTimeProjectFile) +func (f File) Detect(ctx context.Context) (Result, bool, error) { + fp, found := FindFileOrDirectory(ctx, f.Filepath, WakaTimeProjectFile) if !found { return Result{}, false, nil } - log.Debugf("wakatime project file found at: %s", fp) + logger := log.Extract(ctx) + logger.Debugf("wakatime project file found at: %s", fp) - lines, err := ReadFile(fp, 2) + lines, err := ReadFile(ctx, fp, 2) if err != nil { return Result{}, false, fmt.Errorf("error reading file: %s", err) } @@ -49,7 +51,7 @@ func (f File) Detect() (Result, bool, error) { } // ReadFile reads a file until max number of lines and return an array of lines. -func ReadFile(fp string, max int) ([]string, error) { +func ReadFile(ctx context.Context, fp string, max int) ([]string, error) { if fp == "" { return nil, errors.New("filepath cannot be empty") } @@ -59,9 +61,11 @@ func ReadFile(fp string, max int) ([]string, error) { return nil, fmt.Errorf("failed while opening file %q: %s", fp, err) } + logger := log.Extract(ctx) + defer func() { if err := file.Close(); err != nil { - log.Debugf("failed to close file '%s': %s", file.Name(), err) + logger.Debugf("failed to close file '%s': %s", file.Name(), err) } }() diff --git a/pkg/project/file_test.go b/pkg/project/file_test.go index d214b3fb..e4030caa 100644 --- a/pkg/project/file_test.go +++ b/pkg/project/file_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "os" "path/filepath" "testing" @@ -26,7 +27,7 @@ func TestFile_Detect_FileExists(t *testing.T) { Filepath: filepath.Join(tmpDir, ".wakatime-project"), } - result, detected, err := f.Detect() + result, detected, err := f.Detect(context.Background()) require.NoError(t, err) expected := project.Result{ @@ -58,7 +59,7 @@ func TestFile_Detect_ParentFolderExists(t *testing.T) { Filepath: dir, } - result, detected, err := f.Detect() + result, detected, err := f.Detect(context.Background()) require.NoError(t, err) expected := project.Result{ @@ -83,7 +84,7 @@ func TestFile_Detect_NoFileFound(t *testing.T) { Filepath: tmpDir, } - result, detected, err := f.Detect() + result, detected, err := f.Detect(context.Background()) require.NoError(t, err) expected := project.Result{} @@ -102,7 +103,7 @@ func TestFile_Detect_InvalidPath(t *testing.T) { Filepath: tmpFile.Name(), } - _, detected, err := f.Detect() + _, detected, err := f.Detect(context.Background()) require.NoError(t, err) assert.False(t, detected) @@ -122,6 +123,8 @@ func TestFindFileOrDirectory(t *testing.T) { filepath.Join(tmpDir, ".wakatime-project"), ) + ctx := context.Background() + tests := map[string]struct { Filepath string Filename string @@ -141,7 +144,7 @@ func TestFindFileOrDirectory(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - fp, ok := project.FindFileOrDirectory(test.Filepath, test.Filename) + fp, ok := project.FindFileOrDirectory(ctx, test.Filepath, test.Filename) require.True(t, ok) assert.Equal(t, test.Expected, fp) diff --git a/pkg/project/filter.go b/pkg/project/filter.go index 277ffd61..f8c8283c 100644 --- a/pkg/project/filter.go +++ b/pkg/project/filter.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "os" @@ -19,20 +20,21 @@ type FilterConfig struct { // the provided configurations. func WithFiltering(config FilterConfig) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute project filtering") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute project filtering") var filtered []heartbeat.Heartbeat for _, h := range hh { err := Filter(h, config) if err != nil { - log.Debugln(err.Error()) + logger.Debugln(err.Error()) if h.LocalFileNeedsCleanup { err = os.Remove(h.LocalFile) if err != nil { - log.Warnf("unable to delete tmp file: %s", err) + logger.Warnf("unable to delete tmp file: %s", err) } } @@ -42,7 +44,7 @@ func WithFiltering(config FilterConfig) heartbeat.HandleOption { filtered = append(filtered, h) } - return next(filtered) + return next(ctx, filtered) } } } diff --git a/pkg/project/filter_test.go b/pkg/project/filter_test.go index 3e240972..ce66162e 100644 --- a/pkg/project/filter_test.go +++ b/pkg/project/filter_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "os" "testing" @@ -26,7 +27,7 @@ func TestWithFiltering(t *testing.T) { opt := project.WithFiltering(project.FilterConfig{ ExcludeUnknownProject: true, }) - h := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + h := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo("heartbeat"), @@ -52,7 +53,7 @@ func TestWithFiltering(t *testing.T) { }, nil }) - result, err := h([]heartbeat.Heartbeat{first, second}) + result, err := h(context.Background(), []heartbeat.Heartbeat{first, second}) require.NoError(t, err) assert.Equal(t, []heartbeat.Result{ diff --git a/pkg/project/git.go b/pkg/project/git.go index bd761c28..6e40345a 100644 --- a/pkg/project/git.go +++ b/pkg/project/git.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "path/filepath" "strings" @@ -23,7 +24,8 @@ type Git struct { // Detect gets information about the git project for a given file. // It tries to return a project and branch name. -func (g Git) Detect() (Result, bool, error) { +func (g Git) Detect(ctx context.Context) (Result, bool, error) { + logger := log.Extract(ctx) fp := g.Filepath // Take only the directory @@ -32,22 +34,22 @@ func (g Git) Detect() (Result, bool, error) { } // Find for submodule takes priority if enabled - gitdirSubmodule, ok, err := findSubmodule(fp, g.SubmoduleDisabledPatterns) + gitdirSubmodule, ok, err := findSubmodule(ctx, fp, g.SubmoduleDisabledPatterns) if err != nil { return Result{}, false, fmt.Errorf("failed to find submodule: %s", err) } if ok { - project := projectOrRemote(filepath.Base(gitdirSubmodule), g.ProjectFromGitRemote, gitdirSubmodule) + project := projectOrRemote(ctx, filepath.Base(gitdirSubmodule), g.ProjectFromGitRemote, gitdirSubmodule) // If submodule has a project map, then use it. - if result, ok := matchPattern(gitdirSubmodule, g.SubmoduleProjectMapPatterns); ok { + if result, ok := matchPattern(ctx, gitdirSubmodule, g.SubmoduleProjectMapPatterns); ok { project = result } - branch, err := findGitBranch(filepath.Join(gitdirSubmodule, "HEAD")) + branch, err := findGitBranch(ctx, filepath.Join(gitdirSubmodule, "HEAD")) if err != nil { - log.Errorf( + logger.Errorf( "error finding branch from %q: %s", filepath.Join(filepath.Dir(gitdirSubmodule), "HEAD"), err, @@ -62,13 +64,13 @@ func (g Git) Detect() (Result, bool, error) { } // Find for .git file or directory - dotGit, found := FindFileOrDirectory(fp, ".git") + dotGit, found := FindFileOrDirectory(ctx, fp, ".git") if !found { return Result{}, false, nil } // Find for gitdir path - gitdir, err := findGitdir(dotGit) + gitdir, err := findGitdir(ctx, dotGit) if err != nil { return Result{}, false, fmt.Errorf("error finding gitdir: %s", err) } @@ -77,7 +79,7 @@ func (g Git) Detect() (Result, bool, error) { // worktree is present but .git folder is not present. In that case, we need to find // for worktree folder. // Find for commondir file - commondir, ok, err := findCommondir(gitdir) + commondir, ok, err := findCommondir(ctx, gitdir) if err != nil { return Result{}, false, fmt.Errorf("error finding commondir: %s", err) } @@ -93,11 +95,11 @@ func (g Git) Detect() (Result, bool, error) { dir = commondir } - project := projectOrRemote(filepath.Base(dir), g.ProjectFromGitRemote, commondir) + project := projectOrRemote(ctx, filepath.Base(dir), g.ProjectFromGitRemote, commondir) - branch, err := findGitBranch(filepath.Join(gitdir, "HEAD")) + branch, err := findGitBranch(ctx, filepath.Join(gitdir, "HEAD")) if err != nil { - log.Errorf( + logger.Errorf( "error finding branch from %q: %s", filepath.Join(filepath.Dir(dotGit), "HEAD"), err, @@ -113,11 +115,11 @@ func (g Git) Detect() (Result, bool, error) { // Otherwise it's only a plain .git file and not a submodule if gitdir != "" && !strings.Contains(gitdir, "modules") { - project := projectOrRemote(filepath.Base(filepath.Join(dotGit, "..")), g.ProjectFromGitRemote, gitdir) + project := projectOrRemote(ctx, filepath.Base(filepath.Join(dotGit, "..")), g.ProjectFromGitRemote, gitdir) - branch, err := findGitBranch(filepath.Join(gitdir, "HEAD")) + branch, err := findGitBranch(ctx, filepath.Join(gitdir, "HEAD")) if err != nil { - log.Errorf( + logger.Errorf( "error finding branch from %q: %s", filepath.Join(filepath.Dir(gitdir), "HEAD"), err, @@ -132,22 +134,22 @@ func (g Git) Detect() (Result, bool, error) { } // Find for .git/config file - gitConfigFile, found := FindFileOrDirectory(fp, filepath.Join(".git", "config")) + gitConfigFile, found := FindFileOrDirectory(ctx, fp, filepath.Join(".git", "config")) if found { gitDir := filepath.Dir(gitConfigFile) projectDir := filepath.Join(gitDir, "..") - branch, err := findGitBranch(filepath.Join(gitDir, "HEAD")) + branch, err := findGitBranch(ctx, filepath.Join(gitDir, "HEAD")) if err != nil { - log.Errorf( + logger.Errorf( "error finding branch from %q: %s", filepath.Join(gitDir, "HEAD"), err, ) } - project := projectOrRemote(filepath.Base(projectDir), g.ProjectFromGitRemote, gitDir) + project := projectOrRemote(ctx, filepath.Base(projectDir), g.ProjectFromGitRemote, gitDir) return Result{ Project: project, @@ -159,17 +161,17 @@ func (g Git) Detect() (Result, bool, error) { return Result{}, false, nil } -func findSubmodule(fp string, patterns []regex.Regex) (string, bool, error) { - if !shouldTakeSubmodule(fp, patterns) { +func findSubmodule(ctx context.Context, fp string, patterns []regex.Regex) (string, bool, error) { + if !shouldTakeSubmodule(ctx, fp, patterns) { return "", false, nil } - gitConfigFile, found := FindFileOrDirectory(fp, ".git") + gitConfigFile, found := FindFileOrDirectory(ctx, fp, ".git") if !found { return "", false, nil } - gitdir, err := findGitdir(gitConfigFile) + gitdir, err := findGitdir(ctx, gitConfigFile) if err != nil { return "", false, fmt.Errorf("error finding gitdir for submodule: %s", err) @@ -184,9 +186,9 @@ func findSubmodule(fp string, patterns []regex.Regex) (string, bool, error) { // shouldTakeSubmodule checks a filepath against the passed in regex patterns to determine, // if submodule filepath should be taken. -func shouldTakeSubmodule(fp string, patterns []regex.Regex) bool { +func shouldTakeSubmodule(ctx context.Context, fp string, patterns []regex.Regex) bool { for _, p := range patterns { - if p.MatchString(fp) { + if p.MatchString(ctx, fp) { return false } } @@ -194,8 +196,8 @@ func shouldTakeSubmodule(fp string, patterns []regex.Regex) bool { return true } -func findGitdir(fp string) (string, error) { - lines, err := ReadFile(fp, 1) +func findGitdir(ctx context.Context, fp string) (string, error) { + lines, err := ReadFile(ctx, fp, 1) if err != nil { return "", fmt.Errorf("failed while opening file %q: %s", fp, err) } @@ -222,7 +224,7 @@ func resolveGitdir(fp, gitdir string) (string, error) { return "", nil } -func findCommondir(fp string) (string, bool, error) { +func findCommondir(ctx context.Context, fp string) (string, bool, error) { if fp == "" { return "", false, nil } @@ -232,14 +234,14 @@ func findCommondir(fp string) (string, bool, error) { } if fileOrDirExists(filepath.Join(fp, "commondir")) { - return resolveCommondir(fp) + return resolveCommondir(ctx, fp) } return "", false, nil } -func resolveCommondir(fp string) (string, bool, error) { - lines, err := ReadFile(filepath.Join(fp, "commondir"), 1) +func resolveCommondir(ctx context.Context, fp string) (string, bool, error) { + lines, err := ReadFile(ctx, filepath.Join(fp, "commondir"), 1) if err != nil { return "", false, fmt.Errorf("failed while opening file %q: %s", fp, err) @@ -258,16 +260,17 @@ func resolveCommondir(fp string) (string, bool, error) { return gitdir, true, nil } -func projectOrRemote(projectName string, projectFromGitRemote bool, dotGitFolder string) string { +func projectOrRemote(ctx context.Context, projectName string, projectFromGitRemote bool, dotGitFolder string) string { if !projectFromGitRemote { return projectName } + logger := log.Extract(ctx) configFile := filepath.Join(dotGitFolder, "config") - remote, err := findGitRemote(configFile) + remote, err := findGitRemote(ctx, configFile) if err != nil { - log.Errorf("error finding git remote from %q: %s", configFile, err) + logger.Errorf("error finding git remote from %q: %s", configFile, err) return projectName } @@ -279,20 +282,22 @@ func projectOrRemote(projectName string, projectFromGitRemote bool, dotGitFolder return projectName } -func findGitBranch(fp string) (string, error) { +func findGitBranch(ctx context.Context, fp string) (string, error) { if !fileOrDirExists(fp) { return "master", nil } - lines, err := ReadFile(fp, 1) + lines, err := ReadFile(ctx, fp, 1) if err != nil { return "", fmt.Errorf("failed while opening file %q: %s", fp, err) } + logger := log.Extract(ctx) + if len(lines) > 0 && strings.HasPrefix(strings.TrimSpace(lines[0]), "ref: ") { parts := strings.SplitN(lines[0], "/", 3) if len(parts) < 3 { - log.Warnf("invalid branch from %q: %s", fp, lines[0]) + logger.Warnf("invalid branch from %q: %s", fp, lines[0]) return "", nil } @@ -303,12 +308,12 @@ func findGitBranch(fp string) (string, error) { return "", nil } -func findGitRemote(fp string) (string, error) { +func findGitRemote(ctx context.Context, fp string) (string, error) { if !fileOrDirExists(fp) { return "", nil } - lines, err := ReadFile(fp, 1000) + lines, err := ReadFile(ctx, fp, 1000) if err != nil { return "", fmt.Errorf("failed while opening file %q: %s", fp, err) } diff --git a/pkg/project/git_test.go b/pkg/project/git_test.go index 7a295a07..e27ffac7 100644 --- a/pkg/project/git_test.go +++ b/pkg/project/git_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "fmt" "os" "path/filepath" @@ -24,7 +25,7 @@ func TestGit_Detect(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -43,7 +44,7 @@ func TestGit_Detect_BranchWithSlash(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -62,7 +63,7 @@ func TestGit_Detect_DetachedHead(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -101,7 +102,7 @@ func TestGit_Detect_GitConfigFile_File(t *testing.T) { Filepath: test.Filepath, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -126,7 +127,7 @@ func TestGit_Detect_GitConfigFile_File_MalformedHEAD(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -145,7 +146,7 @@ func TestGit_Detect_Worktree(t *testing.T) { Filepath: filepath.Join(fp, "api/src/pkg/file.go"), } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -165,7 +166,7 @@ func TestGit_Detect_WorktreeGitRemote(t *testing.T) { ProjectFromGitRemote: true, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -184,7 +185,7 @@ func TestGit_Detect_Worktree_BareRepo(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/master/src/pkg/file.go"), } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -204,7 +205,7 @@ func TestGit_Detect_WorktreeGitRemote_BareRepo(t *testing.T) { ProjectFromGitRemote: true, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -221,10 +222,10 @@ func TestGit_Detect_Submodule(t *testing.T) { g := project.Git{ Filepath: filepath.Join(fp, "wakatime-cli/lib/billing/src/lib/lib.cpp"), - SubmoduleDisabledPatterns: []regex.Regex{regexp.MustCompile("not_matching")}, + SubmoduleDisabledPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("not_matching"))}, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -241,10 +242,10 @@ func TestGit_Detect_SubmoduleDisabled(t *testing.T) { g := project.Git{ Filepath: filepath.Join(fp, "wakatime-cli/lib/billing/src/lib/lib.cpp"), - SubmoduleDisabledPatterns: []regex.Regex{regexp.MustCompile(".*billing.*")}, + SubmoduleDisabledPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*billing.*"))}, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -264,12 +265,12 @@ func TestGit_Detect_SubmoduleProjectMap_NotMatch(t *testing.T) { SubmoduleProjectMapPatterns: []project.MapPattern{ { Name: "my-project-1", - Regex: regexp.MustCompile(formatRegex("not_matching")), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex("not_matching"))), }, }, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -289,12 +290,12 @@ func TestGit_Detect_SubmoduleProjectMap(t *testing.T) { SubmoduleProjectMapPatterns: []project.MapPattern{ { Name: "my-project-1", - Regex: regexp.MustCompile(formatRegex(".*billing.*")), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(".*billing.*"))), }, }, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -312,10 +313,10 @@ func TestGit_Detect_SubmoduleGitRemote(t *testing.T) { g := project.Git{ Filepath: filepath.Join(fp, "wakatime-cli/lib/billing/src/lib/lib.cpp"), ProjectFromGitRemote: true, - SubmoduleDisabledPatterns: []regex.Regex{regexp.MustCompile("not_matching")}, + SubmoduleDisabledPatterns: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile("not_matching"))}, } - result, detected, err := g.Detect() + result, detected, err := g.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) diff --git a/pkg/project/map.go b/pkg/project/map.go index af6770b4..e0b91789 100644 --- a/pkg/project/map.go +++ b/pkg/project/map.go @@ -1,6 +1,7 @@ package project import ( + "context" "path/filepath" "github.com/wakatime/wakatime-cli/pkg/log" @@ -26,8 +27,8 @@ type Map struct { // Will result in file '/home/user/projects/foo/src/main.c' to have // project name 'new project name' and file '/home/user/projects/bar42/main.c' // to have project name 'project42'. -func (m Map) Detect() (Result, bool, error) { - result, ok := matchPattern(m.Filepath, m.Patterns) +func (m Map) Detect(ctx context.Context) (Result, bool, error) { + result, ok := matchPattern(ctx, m.Filepath, m.Patterns) if !ok { return Result{}, false, nil } @@ -39,10 +40,12 @@ func (m Map) Detect() (Result, bool, error) { } // matchPattern matches regex against entity's path to find project name. -func matchPattern(fp string, patterns []MapPattern) (string, bool) { +func matchPattern(ctx context.Context, fp string, patterns []MapPattern) (string, bool) { + logger := log.Extract(ctx) + for _, pattern := range patterns { - if pattern.Regex.MatchString(fp) { - matches := pattern.Regex.FindStringSubmatch(fp) + if pattern.Regex.MatchString(ctx, fp) { + matches := pattern.Regex.FindStringSubmatch(ctx, fp) if len(matches) > 0 { params := make([]any, len(matches[1:])) for i, v := range matches[1:] { @@ -51,7 +54,7 @@ func matchPattern(fp string, patterns []MapPattern) (string, bool) { result, err := pyfmt.Fmt(pattern.Name, params...) if err != nil { - log.Errorf("error formatting %q: %s", pattern.Name, err) + logger.Errorf("error formatting %q: %s", pattern.Name, err) continue } diff --git a/pkg/project/map_test.go b/pkg/project/map_test.go index 931dbc49..1e87f990 100644 --- a/pkg/project/map_test.go +++ b/pkg/project/map_test.go @@ -1,12 +1,14 @@ package project_test import ( + "context" "os" "path/filepath" "regexp" "testing" "github.com/wakatime/wakatime-cli/pkg/project" + "github.com/wakatime/wakatime-cli/pkg/regex" "github.com/gandarez/go-realpath" "github.com/stretchr/testify/assert" @@ -25,12 +27,12 @@ func TestMap_Detect(t *testing.T) { Patterns: []project.MapPattern{ { Name: "my-project-1", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "testdata"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "testdata")))), }, }, } - result, detected, err := m.Detect() + result, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -51,16 +53,16 @@ func TestMap_Detect_RegexReplace(t *testing.T) { Patterns: []project.MapPattern{ { Name: "my-project-1", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder")))), }, { Name: "my-project-2-{0}", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, `test([a-zA-Z]+)`))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, `test([a-zA-Z]+)`)))), }, }, } - result, detected, err := m.Detect() + result, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -78,16 +80,16 @@ func TestMap_Detect_NoMatch(t *testing.T) { Patterns: []project.MapPattern{ { Name: "my_project_1", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "otherfolder")))), }, { Name: "my_project_2", - Regex: regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "temp"))), + Regex: regex.NewRegexpWrap(regexp.MustCompile(formatRegex(filepath.Join(wd, "path", "to", "temp")))), }, }, } - result, detected, err := m.Detect() + result, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.False(t, detected) @@ -101,7 +103,7 @@ func TestMap_Detect_ZeroPatterns(t *testing.T) { Patterns: []project.MapPattern{}, } - _, detected, err := m.Detect() + _, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.False(t, detected) diff --git a/pkg/project/mercurial.go b/pkg/project/mercurial.go index c3402f03..06d0f8bd 100644 --- a/pkg/project/mercurial.go +++ b/pkg/project/mercurial.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "path/filepath" "strings" @@ -15,7 +16,7 @@ type Mercurial struct { } // Detect gets information about the mercurial project for a given file. -func (m Mercurial) Detect() (Result, bool, error) { +func (m Mercurial) Detect(ctx context.Context) (Result, bool, error) { var fp string // Take only the directory @@ -24,16 +25,17 @@ func (m Mercurial) Detect() (Result, bool, error) { } // Find for .hg folder - hgDirectory, found := FindFileOrDirectory(fp, ".hg") + hgDirectory, found := FindFileOrDirectory(ctx, fp, ".hg") if !found { return Result{}, false, nil } + logger := log.Extract(ctx) project := filepath.Base(filepath.Dir(hgDirectory)) - branch, err := findHgBranch(hgDirectory) + branch, err := findHgBranch(ctx, hgDirectory) if err != nil { - log.Errorf( + logger.Errorf( "error finding for branch name from %q: %s", hgDirectory, err, @@ -47,13 +49,13 @@ func (m Mercurial) Detect() (Result, bool, error) { }, true, nil } -func findHgBranch(fp string) (string, error) { +func findHgBranch(ctx context.Context, fp string) (string, error) { p := filepath.Join(fp, "branch") if !fileOrDirExists(p) { return "default", nil } - lines, err := ReadFile(p, 1) + lines, err := ReadFile(ctx, p, 1) if err != nil { return "", fmt.Errorf("failed while opening file %q: %s", fp, err) } diff --git a/pkg/project/mercurial_test.go b/pkg/project/mercurial_test.go index 686e9837..92efabf6 100644 --- a/pkg/project/mercurial_test.go +++ b/pkg/project/mercurial_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "os" "path/filepath" "testing" @@ -18,7 +19,7 @@ func TestMercurial_Detect(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := m.Detect() + result, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -37,7 +38,7 @@ func TestMercurial_Detect_BranchWithSlash(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := m.Detect() + result, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -56,7 +57,7 @@ func TestMercurial_Detect_NoBranch(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := m.Detect() + result, detected, err := m.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) diff --git a/pkg/project/project.go b/pkg/project/project.go index 4f9a42f2..5967ead5 100644 --- a/pkg/project/project.go +++ b/pkg/project/project.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "math/rand" "os" @@ -85,7 +86,7 @@ func (d DetectorID) String() string { type ( // Detecter is a common interface for project. Detecter interface { - Detect() (Result, bool, error) + Detect(context.Context) (Result, bool, error) ID() DetectorID } @@ -138,14 +139,16 @@ type ( // Last, uses the --alternate-project arg. func WithDetection(config Config) heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + for n, h := range hh { - log.Debugln("execute project detection for:", h.Entity) + logger.Debugln("execute project detection for:", h.Entity) // first, use .wakatime-project or [projectmap] section with entity path. // Then, detect with project folder. This tries to use the same project name // across all IDEs instead of sometimes using alternate project when file is unsaved - result, detector := Detect(config.MapPatterns, + result, detector := Detect(ctx, config.MapPatterns, DetecterArg{Filepath: h.Entity, ShouldRun: h.EntityType == heartbeat.FileType}, DetecterArg{Filepath: h.ProjectPathOverride, ShouldRun: true}, ) @@ -161,6 +164,7 @@ func WithDetection(config Config) heartbeat.HandleOption { // across all IDEs instead of sometimes using alternate project when file is unsaved if result.Project == "" || result.Branch == "" || result.Folder == "" { revControlResult := DetectWithRevControl( + ctx, config.Submodule.DisabledPatterns, config.Submodule.MapPatterns, config.ProjectFromGitRemote, @@ -197,12 +201,12 @@ func WithDetection(config Config) heartbeat.HandleOption { } // finally, obfuscate project name if necessary - if heartbeat.ShouldSanitize(result.Folder, config.HideProjectNames) && + if heartbeat.ShouldSanitize(ctx, result.Folder, config.HideProjectNames) && result.Project != "" && detector != FileDetector { - result.Project = obfuscateProjectName(result.Folder) + result.Project = obfuscateProjectName(ctx, result.Folder) } - result.Folder = FormatProjectFolder(result.Folder) + result.Folder = FormatProjectFolder(ctx, result.Folder) // count total subfolders in project's path if result.Folder != "" && strings.HasPrefix(h.Entity, result.Folder) { @@ -217,13 +221,15 @@ func WithDetection(config Config) heartbeat.HandleOption { hh[n].ProjectPath = result.Folder } - return next(hh) + return next(ctx, hh) } } } // Detect finds the current project and branch from config plugins. -func Detect(patterns []MapPattern, args ...DetecterArg) (Result, DetectorID) { +func Detect(ctx context.Context, patterns []MapPattern, args ...DetecterArg) (Result, DetectorID) { + logger := log.Extract(ctx) + for _, arg := range args { if !arg.ShouldRun || arg.Filepath == "" { continue @@ -240,11 +246,11 @@ func Detect(patterns []MapPattern, args ...DetecterArg) (Result, DetectorID) { } for _, p := range configPlugins { - log.Debugln("execute", p.ID().String()) + logger.Debugln("execute", p.ID().String()) - result, detected, err := p.Detect() + result, detected, err := p.Detect(ctx) if err != nil { - log.Errorf("unexpected error occurred at %q: %s", p.ID().String(), err) + logger.Errorf("unexpected error occurred at %q: %s", p.ID().String(), err) continue } @@ -259,10 +265,13 @@ func Detect(patterns []MapPattern, args ...DetecterArg) (Result, DetectorID) { // DetectWithRevControl finds the current project and branch from rev control. func DetectWithRevControl( + ctx context.Context, submoduleDisabledPatterns []regex.Regex, submoduleProjectMapPatterns []MapPattern, projectFromGitRemote bool, args ...DetecterArg) Result { + logger := log.Extract(ctx) + for _, arg := range args { if !arg.ShouldRun || arg.Filepath == "" { continue @@ -287,11 +296,11 @@ func DetectWithRevControl( } for _, p := range revControlPlugins { - log.Debugln("execute", p.ID().String()) + logger.Debugln("execute", p.ID().String()) - result, detected, err := p.Detect() + result, detected, err := p.Detect(ctx) if err != nil { - log.Errorf("unexpected error occurred at %q: %s", p.ID().String(), err) + logger.Errorf("unexpected error occurred at %q: %s", p.ID().String(), err) continue } @@ -308,17 +317,18 @@ func DetectWithRevControl( return Result{} } -func obfuscateProjectName(folder string) string { +func obfuscateProjectName(ctx context.Context, folder string) string { // prevent overwriting existing project files, use Unknown Project instead if fileOrDirExists(filepath.Join(folder, WakaTimeProjectFile)) { return "" } + logger := log.Extract(ctx) project := generateProjectName() err := Write(folder, project) if err != nil { - log.Warnf("failed to write: %s", err) + logger.Warnf("failed to write: %s", err) } return project @@ -628,7 +638,7 @@ func CountSlashesInProjectFolder(directory string) int { // FindFileOrDirectory searches current and all parent folders for a file or directory named `filename`. // Starts in `directory` and traverses through all parent directories. // `directory` may also be a file, and in that case will start from the file's directory. -func FindFileOrDirectory(directory, filename string) (string, bool) { +func FindFileOrDirectory(ctx context.Context, directory, filename string) (string, bool) { i := 0 for i < maxRecursiveIteration { if isRootPath(directory) { @@ -644,7 +654,8 @@ func FindFileOrDirectory(directory, filename string) (string, bool) { i++ } - log.Warnf("max %d iterations reached without finding %s", maxRecursiveIteration, filename) + logger := log.Extract(ctx) + logger.Warnf("max %d iterations reached without finding %s", maxRecursiveIteration, filename) return "", false } @@ -671,7 +682,7 @@ func firstNonEmptyString(values ...string) string { } // FormatProjectFolder returns the abs and real path for the given directory path. -func FormatProjectFolder(fp string) string { +func FormatProjectFolder(ctx context.Context, fp string) string { if fp == "" { return "" } @@ -680,16 +691,18 @@ func FormatProjectFolder(fp string) string { return windows.FormatFilePath(fp) } + logger := log.Extract(ctx) + formatted, err := filepath.Abs(fp) if err != nil { - log.Debugf("failed to resolve absolute path for %q: %s", fp, err) + logger.Debugf("failed to resolve absolute path for %q: %s", fp, err) return formatted } // evaluate any symlinks formatted, err = realpath.Realpath(formatted) if err != nil { - log.Debugf("failed to resolve real path for %q: %s", formatted, err) + logger.Debugf("failed to resolve real path for %q: %s", formatted, err) } return formatted diff --git a/pkg/project/project_test.go b/pkg/project/project_test.go index 5ce512d5..8649a764 100644 --- a/pkg/project/project_test.go +++ b/pkg/project/project_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "os" "path/filepath" "runtime" @@ -18,6 +19,8 @@ import ( ) func TestWithDetection_EntityNotFile(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { Heartbeats []heartbeat.Heartbeat Override string @@ -72,7 +75,7 @@ func TestWithDetection_EntityNotFile(t *testing.T) { t.Run(name, func(t *testing.T) { opt := project.WithDetection(project.Config{}) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ test.Expected, }, hh) @@ -80,7 +83,7 @@ func TestWithDetection_EntityNotFile(t *testing.T) { return nil, nil }) - _, err := handle(test.Heartbeats) + _, err := handle(ctx, test.Heartbeats) require.NoError(t, err) }) } @@ -89,9 +92,11 @@ func TestWithDetection_EntityNotFile(t *testing.T) { func TestWithDetection_WakatimeProjectTakesPrecedence(t *testing.T) { fp := setupTestGitBasic(t) + ctx := context.Background() + entity := filepath.Join(fp, "wakatime-cli/src/pkg/file.go") projectPath := filepath.Join(fp, "wakatime-cli") - projectPath = project.FormatProjectFolder(projectPath) + projectPath = project.FormatProjectFolder(ctx, projectPath) if runtime.GOOS == "windows" { entity = windows.FormatFilePath(entity) @@ -111,7 +116,7 @@ func TestWithDetection_WakatimeProjectTakesPrecedence(t *testing.T) { } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.NotEmpty(t, hh[0].Project) assert.Equal(t, []heartbeat.Heartbeat{ { @@ -131,7 +136,7 @@ func TestWithDetection_WakatimeProjectTakesPrecedence(t *testing.T) { handle := heartbeat.NewHandle(&sender, opts...) - _, err := handle([]heartbeat.Heartbeat{ + _, err := handle(ctx, []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: entity, @@ -146,7 +151,7 @@ func TestWithDetection_OverrideTakesPrecedence(t *testing.T) { entity := filepath.Join(fp, "wakatime-cli/src/pkg/file.go") projectPath := filepath.Join(fp, "wakatime-cli") - projectPath = project.FormatProjectFolder(projectPath) + projectPath = project.FormatProjectFolder(context.Background(), projectPath) if runtime.GOOS == "windows" { entity = windows.FormatFilePath(entity) @@ -154,7 +159,7 @@ func TestWithDetection_OverrideTakesPrecedence(t *testing.T) { opt := project.WithDetection(project.Config{}) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo("master"), @@ -170,7 +175,7 @@ func TestWithDetection_OverrideTakesPrecedence(t *testing.T) { return nil, nil }) - _, err := handle([]heartbeat.Heartbeat{ + _, err := handle(context.Background(), []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: entity, @@ -191,7 +196,7 @@ func TestWithDetection_OverrideTakesPrecedence_WithProjectPathOverride(t *testin opt := project.WithDetection(project.Config{}) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo("master"), @@ -208,7 +213,7 @@ func TestWithDetection_OverrideTakesPrecedence_WithProjectPathOverride(t *testin return nil, nil }) - _, err := handle([]heartbeat.Heartbeat{ + _, err := handle(context.Background(), []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: entity, @@ -225,10 +230,12 @@ func TestWithDetection_NoneDetected(t *testing.T) { defer tmpFile.Close() + ctx := context.Background() + entity := tmpFile.Name() projectPath := filepath.Dir(tmpFile.Name()) - projectPath = project.FormatProjectFolder(projectPath) + projectPath = project.FormatProjectFolder(ctx, projectPath) if runtime.GOOS == "windows" { entity = windows.FormatFilePath(entity) @@ -243,7 +250,7 @@ func TestWithDetection_NoneDetected(t *testing.T) { } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo(""), @@ -261,7 +268,7 @@ func TestWithDetection_NoneDetected(t *testing.T) { handle := heartbeat.NewHandle(&sender, opts...) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(ctx, []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: tmpFile.Name(), @@ -276,10 +283,12 @@ func TestWithDetection_NoneDetected_AlternateTakesPrecedence(t *testing.T) { defer tmpFile.Close() + ctx := context.Background() + entity := tmpFile.Name() projectPath := filepath.Dir(tmpFile.Name()) - projectPath = project.FormatProjectFolder(projectPath) + projectPath = project.FormatProjectFolder(ctx, projectPath) if runtime.GOOS == "windows" { entity = windows.FormatFilePath(entity) @@ -294,7 +303,7 @@ func TestWithDetection_NoneDetected_AlternateTakesPrecedence(t *testing.T) { } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo("alternate-branch"), @@ -314,7 +323,7 @@ func TestWithDetection_NoneDetected_AlternateTakesPrecedence(t *testing.T) { handle := heartbeat.NewHandle(&sender, opts...) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(ctx, []heartbeat.Heartbeat{ { BranchAlternate: "alternate-branch", EntityType: heartbeat.FileType, @@ -331,6 +340,8 @@ func TestWithDetection_NoneDetected_OverrideTakesPrecedence(t *testing.T) { defer tmpFile.Close() + ctx := context.Background() + entity := tmpFile.Name() if runtime.GOOS == "windows" { @@ -341,7 +352,7 @@ func TestWithDetection_NoneDetected_OverrideTakesPrecedence(t *testing.T) { } projectPath := filepath.Dir(tmpFile.Name()) - projectPath = project.FormatProjectFolder(projectPath) + projectPath = project.FormatProjectFolder(ctx, projectPath) opts := []heartbeat.HandleOption{ heartbeat.WithFormatting(), @@ -349,7 +360,7 @@ func TestWithDetection_NoneDetected_OverrideTakesPrecedence(t *testing.T) { } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo(""), @@ -368,7 +379,7 @@ func TestWithDetection_NoneDetected_OverrideTakesPrecedence(t *testing.T) { handle := heartbeat.NewHandle(&sender, opts...) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(ctx, []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: tmpFile.Name(), @@ -385,6 +396,8 @@ func TestWithDetection_NoneDetected_WithProjectPathOverride(t *testing.T) { defer tmpFile.Close() + ctx := context.Background() + opts := []heartbeat.HandleOption{ heartbeat.WithFormatting(), project.WithDetection(project.Config{}), @@ -399,10 +412,10 @@ func TestWithDetection_NoneDetected_WithProjectPathOverride(t *testing.T) { require.NoError(t, err) } - projectFolder := project.FormatProjectFolder(tmpDir) + projectFolder := project.FormatProjectFolder(ctx, tmpDir) sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Branch: heartbeat.PointerTo(""), @@ -422,7 +435,7 @@ func TestWithDetection_NoneDetected_WithProjectPathOverride(t *testing.T) { handle := heartbeat.NewHandle(&sender, opts...) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(ctx, []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: tmpFile.Name(), @@ -436,9 +449,11 @@ func TestWithDetection_NoneDetected_WithProjectPathOverride(t *testing.T) { func TestWithDetection_ObfuscateProject(t *testing.T) { fp := setupTestGitBasic(t) + ctx := context.Background() + entity := filepath.Join(fp, "wakatime-cli/src/pkg/file.go") projectPath := filepath.Join(fp, "wakatime-cli") - projectPath = project.FormatProjectFolder(projectPath) + projectPath = project.FormatProjectFolder(ctx, projectPath) if runtime.GOOS == "windows" { entity = windows.FormatFilePath(entity) @@ -448,7 +463,7 @@ func TestWithDetection_ObfuscateProject(t *testing.T) { HideProjectNames: []regex.Regex{regex.MustCompile(".*")}, }) - handle := opt(func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + handle := opt(func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.NotEmpty(t, hh[0].Project) assert.Equal(t, []heartbeat.Heartbeat{ { @@ -464,7 +479,7 @@ func TestWithDetection_ObfuscateProject(t *testing.T) { return nil, nil }) - _, err := handle([]heartbeat.Heartbeat{ + _, err := handle(ctx, []heartbeat.Heartbeat{ { EntityType: heartbeat.FileType, Entity: entity, @@ -491,7 +506,7 @@ func TestDetect_FileDetected(t *testing.T) { filepath.Join(tmpDir, "entity.any"), ) - result, detector := project.Detect([]project.MapPattern{}, project.DetecterArg{ + result, detector := project.Detect(context.Background(), []project.MapPattern{}, project.DetecterArg{ Filepath: filepath.Join(tmpDir, "entity.any"), ShouldRun: true, }) @@ -520,7 +535,7 @@ func TestDetect_EmptyFileDetected(t *testing.T) { filepath.Join(tmpDir, "wakatime-cli", "entity.any"), ) - result, detector := project.Detect([]project.MapPattern{}, project.DetecterArg{ + result, detector := project.Detect(context.Background(), []project.MapPattern{}, project.DetecterArg{ Filepath: filepath.Join(tmpDir, "wakatime-cli", "entity.any"), ShouldRun: true, }) @@ -550,7 +565,7 @@ func TestDetect_MapDetected(t *testing.T) { }, } - result, detector := project.Detect(patterns, project.DetecterArg{ + result, detector := project.Detect(context.Background(), patterns, project.DetecterArg{ Filepath: tmpFile.Name(), ShouldRun: true, }) @@ -565,6 +580,7 @@ func TestDetectWithRevControl_GitDetected(t *testing.T) { fp := setupTestGitBasic(t) result := project.DetectWithRevControl( + context.Background(), []regex.Regex{}, []project.MapPattern{}, false, @@ -586,6 +602,7 @@ func TestDetectWithRevControl_GitRemoteDetected(t *testing.T) { fp := setupTestGitBasic(t) result := project.DetectWithRevControl( + context.Background(), []regex.Regex{}, []project.MapPattern{}, true, @@ -609,7 +626,7 @@ func TestDetect_NoProjectDetected(t *testing.T) { defer tmpFile.Close() - result, detector := project.Detect([]project.MapPattern{}, project.DetecterArg{ + result, detector := project.Detect(context.Background(), []project.MapPattern{}, project.DetecterArg{ Filepath: tmpFile.Name(), ShouldRun: true, }) @@ -705,11 +722,11 @@ func formatRegex(fp string) string { } type mockSender struct { - SendHeartbeatsFn func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) + SendHeartbeatsFn func(context.Context, []heartbeat.Heartbeat) ([]heartbeat.Result, error) SendHeartbeatsFnInvoked bool } -func (m *mockSender) SendHeartbeats(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { +func (m *mockSender) SendHeartbeats(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { m.SendHeartbeatsFnInvoked = true - return m.SendHeartbeatsFn(hh) + return m.SendHeartbeatsFn(ctx, hh) } diff --git a/pkg/project/subversion.go b/pkg/project/subversion.go index 67955721..b0713670 100644 --- a/pkg/project/subversion.go +++ b/pkg/project/subversion.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "os/exec" "path/filepath" @@ -17,10 +18,12 @@ type Subversion struct { } // Detect gets information about the svn project for a given file. -func (s Subversion) Detect() (Result, bool, error) { - binary, ok := findSvnBinary() +func (s Subversion) Detect(ctx context.Context) (Result, bool, error) { + logger := log.Extract(ctx) + + binary, ok := findSvnBinary(ctx) if !ok { - log.Debugln("svn binary not found") + logger.Debugln("svn binary not found") return Result{}, false, nil } @@ -32,7 +35,7 @@ func (s Subversion) Detect() (Result, bool, error) { } // Find for .svn/wc.db file - svnConfigFile, found := FindFileOrDirectory(fp, filepath.Join(".svn", "wc.db")) + svnConfigFile, found := FindFileOrDirectory(ctx, fp, filepath.Join(".svn", "wc.db")) if !found { return Result{}, false, nil } @@ -77,19 +80,21 @@ func svnInfo(fp string, binary string) (map[string]string, bool, error) { return result, true, nil } -func findSvnBinary() (string, bool) { +func findSvnBinary(ctx context.Context) (string, bool) { locations := []string{ "svn", "/usr/bin/svn", "/usr/local/bin/svn", } + logger := log.Extract(ctx) + for _, loc := range locations { cmd := exec.Command(loc, "--version") // nolint:gosec err := cmd.Run() if err != nil { - log.Debugf("failed while calling %s --version: %s", loc, err) + logger.Debugf("failed while calling %s --version: %s", loc, err) continue } diff --git a/pkg/project/subversion_test.go b/pkg/project/subversion_test.go index 5af76e2a..b95f0e98 100644 --- a/pkg/project/subversion_test.go +++ b/pkg/project/subversion_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "os" "os/exec" "path/filepath" @@ -21,7 +22,7 @@ func TestSubversion_Detect(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli", "src", "pkg", "file.go"), } - result, detected, err := s.Detect() + result, detected, err := s.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -41,7 +42,7 @@ func TestSubversion_Detect_Branch(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli/src/pkg/file.go"), } - result, detected, err := s.Detect() + result, detected, err := s.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) diff --git a/pkg/project/tfvc.go b/pkg/project/tfvc.go index 521929e7..cae8b187 100644 --- a/pkg/project/tfvc.go +++ b/pkg/project/tfvc.go @@ -1,6 +1,7 @@ package project import ( + "context" "path/filepath" "runtime" ) @@ -12,7 +13,7 @@ type Tfvc struct { } // Detect gets information about the tfvc project for a given file. -func (t Tfvc) Detect() (Result, bool, error) { +func (t Tfvc) Detect(ctx context.Context) (Result, bool, error) { var fp string // Take only the directory @@ -26,7 +27,7 @@ func (t Tfvc) Detect() (Result, bool, error) { } // Find for tf/properties.tf1 file - tfDirectory, found := FindFileOrDirectory(fp, filepath.Join(tfFolderName, "properties.tf1")) + tfDirectory, found := FindFileOrDirectory(ctx, fp, filepath.Join(tfFolderName, "properties.tf1")) if !found { return Result{}, false, nil } diff --git a/pkg/project/tfvc_test.go b/pkg/project/tfvc_test.go index 6e79b9a5..81fb3229 100644 --- a/pkg/project/tfvc_test.go +++ b/pkg/project/tfvc_test.go @@ -1,6 +1,7 @@ package project_test import ( + "context" "fmt" "os" "path/filepath" @@ -24,7 +25,7 @@ func TestTfvc_Detect(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli", "src", "pkg", "file.go"), } - result, detected, err := s.Detect() + result, detected, err := s.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) @@ -46,7 +47,7 @@ func TestTfvc_Detect_Windows(t *testing.T) { Filepath: filepath.Join(fp, "wakatime-cli", "src", "pkg", "file.go"), } - result, detected, err := s.Detect() + result, detected, err := s.Detect(context.Background()) require.NoError(t, err) assert.True(t, detected) diff --git a/pkg/regex/regex.go b/pkg/regex/regex.go index de7d3d6c..85997965 100644 --- a/pkg/regex/regex.go +++ b/pkg/regex/regex.go @@ -1,6 +1,7 @@ package regex import ( + "context" "fmt" "regexp" @@ -11,8 +12,8 @@ import ( // Regex interface to use regexp.Regexp and regexp2.Regexp interchangeably. type Regex interface { - FindStringSubmatch(s string) []string - MatchString(s string) bool + FindStringSubmatch(ctx context.Context, s string) []string + MatchString(ctx context.Context, s string) bool String() string } @@ -21,7 +22,9 @@ type Regex interface { func Compile(s string) (Regex, error) { r, err := regexp.Compile(s) if err == nil { - return r, nil + return &RegexpWrap{ + rgx: r, + }, nil } r2, err := regexp2.Compile(s, 0) @@ -46,6 +49,39 @@ func MustCompile(s string) Regex { return r } +// RegexpWrap is a wrapper around regexp.Regexp, which conforms to regexp.Regexp interface. +// Only supports a subset of methods. +type RegexpWrap struct { + rgx *regexp.Regexp +} + +// NewRegexpWrap returns a new instance of regexpWrap. +func NewRegexpWrap(rgx *regexp.Regexp) *RegexpWrap { + return &RegexpWrap{ + rgx: rgx, + } +} + +// FindStringSubmatch returns a slice of strings holding the text of the +// leftmost match of the regular expression in s and the matches, if any, of +// its subexpressions, as defined by the 'Submatch' description in the +// package comment. +// A return value of nil indicates no match. +func (re *RegexpWrap) FindStringSubmatch(_ context.Context, s string) []string { + return re.rgx.FindStringSubmatch(s) +} + +// MatchString reports whether the string s +// contains any match of the regular expression re. +func (re *RegexpWrap) MatchString(_ context.Context, s string) bool { + return re.rgx.MatchString(s) +} + +// String returns the source text used to compile the regular expression. +func (re *RegexpWrap) String() string { + return re.rgx.String() +} + // regexp2Wrap is a wrapper around github.com/dlclark/regexp2.Regexp, which conforms // to regexp.Regexp interface. Only supports a subset of methods. type regexp2Wrap struct { @@ -56,10 +92,12 @@ type regexp2Wrap struct { // match of the regular expression in s and the matches, if any, of its // subexpressions, as defined by the 'Submatch' description in the package comment. // A return value of nil indicates no match. -func (re *regexp2Wrap) FindStringSubmatch(s string) []string { +func (re *regexp2Wrap) FindStringSubmatch(ctx context.Context, s string) []string { + logger := log.Extract(ctx) + m, err := re.rgx.FindStringMatch(s) if err != nil { - log.Warnf("failed to find string match %q: %s", s, err) + logger.Warnf("failed to find string match %q: %s", s, err) return nil } @@ -80,10 +118,12 @@ func (re *regexp2Wrap) FindStringSubmatch(s string) []string { // MatchString reports whether the string s contains any match of the regular // expression re. -func (re *regexp2Wrap) MatchString(s string) bool { +func (re *regexp2Wrap) MatchString(ctx context.Context, s string) bool { + logger := log.Extract(ctx) + matched, err := re.rgx.MatchString(s) if err != nil { - log.Warnf("failed to match string %q: %s", s, err) + logger.Warnf("failed to match string %q: %s", s, err) return false } diff --git a/pkg/regex/regex_internal_test.go b/pkg/regex/regex_internal_test.go index 36a16b7a..11669856 100644 --- a/pkg/regex/regex_internal_test.go +++ b/pkg/regex/regex_internal_test.go @@ -1,6 +1,7 @@ package regex import ( + "context" "testing" "github.com/dlclark/regexp2" @@ -10,6 +11,8 @@ import ( ) func TestRegexp2Wrap_MatchString(t *testing.T) { + ctx := context.Background() + tests := map[string]bool{ "gopher": false, "gophergopher": true, @@ -25,12 +28,14 @@ func TestRegexp2Wrap_MatchString(t *testing.T) { rgx: r2, } - assert.Equal(t, expected, r.MatchString(str)) + assert.Equal(t, expected, r.MatchString(ctx, str)) }) } } func TestRegexp2Wrap_FindStringSubmatch(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { String string Expected []string @@ -54,7 +59,7 @@ func TestRegexp2Wrap_FindStringSubmatch(t *testing.T) { rgx: r2, } - matches := r.FindStringSubmatch(test.String) + matches := r.FindStringSubmatch(ctx, test.String) assert.Equal(t, test.Expected, matches) }) diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index bac4c747..044e489e 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -49,8 +49,9 @@ type Client struct { // download to a temporary directory. func WithDetection() heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute remote file detection") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute remote file detection") var filtered []heartbeat.Heartbeat @@ -62,29 +63,29 @@ func WithDetection() heartbeat.HandleOption { tmpFile, err := os.CreateTemp("", fmt.Sprintf("*_%s", filepath.Base(h.Entity))) if err != nil { - log.Errorf("failed to create temporary file: %s", err) + logger.Errorf("failed to create temporary file: %s", err) continue } - c, err := NewClient(h.Entity) + c, err := NewClient(ctx, h.Entity) if err != nil { - log.Errorf("failed to create new remote client: %s", err) + logger.Errorf("failed to create new remote client: %s", err) - deleteLocalFile(tmpFile.Name()) + deleteLocalFile(ctx, tmpFile.Name()) continue } - err = c.DownloadFile(tmpFile.Name()) + err = c.DownloadFile(ctx, tmpFile.Name()) if err != nil { - log.Errorf("failed to download file to temporary folder: %s", err) + logger.Errorf("failed to download file to temporary folder: %s", err) - err = c.DownloadFileFallback(tmpFile.Name()) + err = c.DownloadFileFallback(ctx, tmpFile.Name()) if err != nil { - log.Errorf("failed to download remote file using fallback option: %s", err) + logger.Errorf("failed to download remote file using fallback option: %s", err) } - deleteLocalFile(tmpFile.Name()) + deleteLocalFile(ctx, tmpFile.Name()) continue } @@ -95,7 +96,7 @@ func WithDetection() heartbeat.HandleOption { filtered = append(filtered, h) } - return next(filtered) + return next(ctx, filtered) } } } @@ -104,31 +105,34 @@ func WithDetection() heartbeat.HandleOption { // deletes a local temporary file if downloaded from a remote file. func WithCleanup() heartbeat.HandleOption { return func(next heartbeat.Handle) heartbeat.Handle { - return func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { - log.Debugln("execute remote cleanup") + return func(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + logger := log.Extract(ctx) + logger.Debugln("execute remote cleanup") for _, h := range hh { if h.LocalFileNeedsCleanup { - log.Debugln("deleting temporary file:", h.LocalFile) + logger.Debugln("deleting temporary file:", h.LocalFile) - deleteLocalFile(h.LocalFile) + deleteLocalFile(ctx, h.LocalFile) } } - return next(hh) + return next(ctx, hh) } } } -func deleteLocalFile(fp string) { +func deleteLocalFile(ctx context.Context, fp string) { + logger := log.Extract(ctx) + err := os.Remove(fp) if err != nil { - log.Warnf("unable to delete tmp file: %s", err) + logger.Warnf("unable to delete tmp file: %s", err) } } // NewClient initializes a new remote client. -func NewClient(address string) (Client, error) { +func NewClient(ctx context.Context, address string) (Client, error) { parsedURL, err := url.Parse(address) if err != nil { return Client{}, fmt.Errorf("failed to parse remote file url: %s", err) @@ -150,14 +154,16 @@ func NewClient(address string) (Client, error) { derivedHost = host } + logger := log.Extract(ctx) + if port == 0 { port, err = strconv.Atoi(ssh_config.Get(host, "Port")) - log.Warnf("failed to parse port from host: %s", err) + logger.Warnf("failed to parse port from host: %s", err) } if port == 0 { port, err = strconv.Atoi(ssh_config.Get(derivedHost, "Port")) - log.Warnf("failed to parse port from derived host: %s", err) + logger.Warnf("failed to parse port from derived host: %s", err) } if port == 0 { @@ -167,7 +173,7 @@ func NewClient(address string) (Client, error) { return Client{ User: parsedURL.User.Username(), Pass: pass, - HostKeyAlias: hostKeyAlias(host, derivedHost), + HostKeyAlias: hostKeyAlias(ctx, host, derivedHost), OriginalHost: host, Host: derivedHost, Port: port, @@ -176,19 +182,21 @@ func NewClient(address string) (Client, error) { } // DownloadFile downloads a remote file and copy to a local file. -func (c Client) DownloadFile(localFile string) error { - conn, sc, err := c.Connect() +func (c Client) DownloadFile(ctx context.Context, localFile string) error { + conn, sc, err := c.Connect(ctx) if err != nil { return fmt.Errorf("failed to connect to sftp host: %s", err) } + logger := log.Extract(ctx) + defer func() { if err := conn.Close(); err != nil { - log.Debugf("failed to close connection to ssh server: %s", err) + logger.Debugf("failed to close connection to ssh server: %s", err) } if err := sc.Close(); err != nil { - log.Debugf("failed to close connection to ftp server: %s", err) + logger.Debugf("failed to close connection to ftp server: %s", err) } }() @@ -199,7 +207,7 @@ func (c Client) DownloadFile(localFile string) error { defer func() { if err := srcFile.Close(); err != nil { - log.Debugf("failed to close remote ftp file: %s", err) + logger.Debugf("failed to close remote ftp file: %s", err) } }() @@ -210,7 +218,7 @@ func (c Client) DownloadFile(localFile string) error { defer func() { if err := dstFile.Close(); err != nil { - log.Warnf("failed to close local file: %s", err) + logger.Warnf("failed to close local file: %s", err) } }() @@ -223,11 +231,13 @@ func (c Client) DownloadFile(localFile string) error { } // DownloadFileFallback downloads a remote file and copy to a local file using machine's ssh. -func (c Client) DownloadFileFallback(localFile string) error { +func (c Client) DownloadFileFallback(ctx context.Context, localFile string) error { + logger := log.Extract(ctx) + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutSecs*time.Second) defer cancel() - log.Debugln("downloading remote file using fallback option") + logger.Debugln("downloading remote file using fallback option") cmd := exec.CommandContext(ctx, "scp", "-B", fmt.Sprintf("%s:%s", c.OriginalHost, c.Path), localFile) // nolint:gosec @@ -242,9 +252,9 @@ func (c Client) DownloadFileFallback(localFile string) error { } // Connect connects to sftp host. -func (c Client) Connect() (*ssh.Client, *sftp.Client, error) { +func (c Client) Connect(ctx context.Context) (*ssh.Client, *sftp.Client, error) { // Initialize client configuration - sshClient, err := c.sshClient() + sshClient, err := c.sshClient(ctx) if err != nil { return nil, nil, err } @@ -259,7 +269,8 @@ func (c Client) Connect() (*ssh.Client, *sftp.Client, error) { } // knownHostKeys gets all host keys from local known hosts for given hosts. -func (c Client) knownHostKeys() []ssh.PublicKey { +func (c Client) knownHostKeys(ctx context.Context) []ssh.PublicKey { + logger := log.Extract(ctx) hostKeys := []ssh.PublicKey{} filenames := c.knownHostsFiles() @@ -273,7 +284,7 @@ func (c Client) knownHostKeys() []ssh.PublicKey { defer func() { if err := file.Close(); err != nil { - log.Debugf("failed to close file '%s': %s", file.Name(), err) + logger.Debugf("failed to close file '%s': %s", file.Name(), err) } }() @@ -290,7 +301,7 @@ func (c Client) knownHostKeys() []ssh.PublicKey { if contains(hostnames, c.HostKeyAlias, c.OriginalHost, c.Host) { hostKey, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes()) if err != nil { - log.Warnf("failed to parse %q: %s", fields[2], err) + logger.Warnf("failed to parse %q: %s", fields[2], err) } else { hostKeys = append(hostKeys, hostKey) } @@ -299,7 +310,7 @@ func (c Client) knownHostKeys() []ssh.PublicKey { return nil }(filename); err != nil { - log.Debugln(err) + logger.Debugln(err) } } @@ -394,29 +405,33 @@ func (c Client) signerForIdentity() (ssh.Signer, error) { return signer, nil } -func (c Client) warnIfUsingRevokedHostKeys() { +func (c Client) warnIfUsingRevokedHostKeys(ctx context.Context) { + logger := log.Extract(ctx) + revokedKeysFile := ssh_config.Get(c.OriginalHost, "RevokedHostKeys") if revokedKeysFile != "" { - log.Warnln("Using ssh config RevokedHostKeys is not supported") + logger.Warnln("using ssh config RevokedHostKeys is not supported") return } if c.OriginalHost != c.Host { revokedKeysFile = ssh_config.Get(c.Host, "RevokedHostKeys") if revokedKeysFile != "" { - log.Warnln("Using ssh config RevokedHostKeys is not supported") + logger.Warnln("using ssh config RevokedHostKeys is not supported") } } } -func (c Client) sshClient() (*ssh.Client, error) { +func (c Client) sshClient(ctx context.Context) (*ssh.Client, error) { + logger := log.Extract(ctx) + var auths []ssh.AuthMethod addr := fmt.Sprintf("%s:%d", c.Host, c.Port) signer, err := c.signerForIdentity() if err != nil { - log.Warnf("%s", err) + logger.Warnf("%s", err) } if signer != nil { @@ -441,10 +456,10 @@ func (c Client) sshClient() (*ssh.Client, error) { } strict := c.strictHostKeyChecking() - log.Debugf("StrictHostKeyChecking for %s set to %s", c.OriginalHost, strict) + logger.Debugf("StrictHostKeyChecking for %s set to %s", c.OriginalHost, strict) if strict == "no" { - log.Debugf("host key checking disabled for %s", c.OriginalHost) + logger.Debugf("host key checking disabled for %s", c.OriginalHost) config.HostKeyCallback = ssh.InsecureIgnoreHostKey() // nolint:gosec @@ -457,13 +472,13 @@ func (c Client) sshClient() (*ssh.Client, error) { return client, nil } - knownHostKeys := c.knownHostKeys() + knownHostKeys := c.knownHostKeys(ctx) if len(knownHostKeys) == 0 && strict == "yes" { return nil, fmt.Errorf("known host key not found for %s, will not connect", c.OriginalHost) } if len(knownHostKeys) == 0 { - log.Debugf("no known host key found for %s, will connect anyway", c.OriginalHost) + logger.Debugf("no known host key found for %s, will connect anyway", c.OriginalHost) config.HostKeyCallback = ssh.InsecureIgnoreHostKey() // nolint:gosec @@ -476,9 +491,9 @@ func (c Client) sshClient() (*ssh.Client, error) { return client, nil } - log.Debugf("found %d known host ssh keys for %s", len(knownHostKeys), c.OriginalHost) + logger.Debugf("found %d known host ssh keys for %s", len(knownHostKeys), c.OriginalHost) - c.warnIfUsingRevokedHostKeys() + c.warnIfUsingRevokedHostKeys(ctx) for _, hostKey := range knownHostKeys { config.HostKeyCallback = ssh.FixedHostKey(hostKey) @@ -486,7 +501,7 @@ func (c Client) sshClient() (*ssh.Client, error) { // Connect to server client, err := dial(addr, &config) if err != nil { - log.Warnf("failed to connect to '%s': %s", addr, err) + logger.Warnf("failed to connect to '%s': %s", addr, err) continue } @@ -519,7 +534,7 @@ func (c Client) user() string { return "" } -func hostKeyAlias(hostOriginal string, hostDerived string) string { +func hostKeyAlias(ctx context.Context, hostOriginal, hostDerived string) string { alias := ssh_config.Get(hostOriginal, "HostKeyAlias") if alias == "" && hostOriginal != hostDerived { alias = ssh_config.Get(hostDerived, "HostKeyAlias") @@ -529,9 +544,11 @@ func hostKeyAlias(hostOriginal string, hostDerived string) string { return "" } + logger := log.Extract(ctx) + alias, err := homedir.Expand(alias) if err != nil { - log.Debugf("Unable to expand home directory for HostKeyAlias %q: %w", alias, err) + logger.Debugf("Unable to expand home directory for HostKeyAlias %q: %s", alias, err) } return alias diff --git a/pkg/remote/remote_test.go b/pkg/remote/remote_test.go index 2215e4ee..d1fed521 100644 --- a/pkg/remote/remote_test.go +++ b/pkg/remote/remote_test.go @@ -1,7 +1,7 @@ package remote_test import ( - "bytes" + "context" "encoding/hex" "fmt" "io" @@ -14,21 +14,27 @@ import ( "strings" "testing" + "github.com/wakatime/wakatime-cli/cmd" "github.com/wakatime/wakatime-cli/pkg/filter" "github.com/wakatime/wakatime-cli/pkg/heartbeat" "github.com/wakatime/wakatime-cli/pkg/log" "github.com/wakatime/wakatime-cli/pkg/regex" "github.com/wakatime/wakatime-cli/pkg/remote" + "github.com/wakatime/wakatime-cli/pkg/windows" "github.com/kevinburke/ssh_config" "github.com/pkg/sftp" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) func TestNewClient(t *testing.T) { - client, err := remote.NewClient("ssh://wakatime:1234@192.168.1.2:222/home/pi/unicorn-hat/examples/ascii_pic.py") + client, err := remote.NewClient( + context.Background(), + "ssh://wakatime:1234@192.168.1.2:222/home/pi/unicorn-hat/examples/ascii_pic.py", + ) require.NoError(t, err) assert.Equal(t, remote.Client{ @@ -42,7 +48,7 @@ func TestNewClient(t *testing.T) { } func TestNewClient_Sftp(t *testing.T) { - client, err := remote.NewClient("sftp://127.0.0.1") + client, err := remote.NewClient(context.Background(), "sftp://127.0.0.1") require.NoError(t, err) assert.Equal(t, remote.Client{ @@ -56,7 +62,7 @@ func TestNewClient_Sftp(t *testing.T) { } func TestNewClient_Err(t *testing.T) { - _, err := remote.NewClient("ssh://wakatime:1234@192.168.1.2:port") + _, err := remote.NewClient(context.Background(), "ssh://wakatime:1234@192.168.1.2:port") require.Error(t, err) assert.EqualError(t, err, @@ -64,10 +70,6 @@ func TestNewClient_Err(t *testing.T) { } func TestWithDetection_SshConfig_Hostname(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is Windows.") - } - shutdown, host, port := testServer(t, false) defer shutdown() @@ -90,14 +92,24 @@ func TestWithDetection_SshConfig_Hostname(t *testing.T) { err = os.WriteFile(tmpFile.Name(), []byte(fmt.Sprintf(string(template), host)), 0600) require.NoError(t, err) - entity, _ := filepath.Abs("./testdata/main.go") + entityFilepath, err := filepath.Abs("./testdata/main.go") + require.NoError(t, err) + + entity := "ssh://user:pass@example.com:" + strconv.Itoa(port) + + if runtime.GOOS == "windows" { + entityFilepath = windows.FormatFilePath(entityFilepath) + entity += "/" + entityFilepath + } else { + entity += entityFilepath + } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, - Entity: "ssh://user:pass@example.com:" + strconv.Itoa(port) + entity, + Entity: entity, EntityType: heartbeat.FileType, LocalFile: hh[0].LocalFile, LocalFileNeedsCleanup: true, @@ -120,10 +132,10 @@ func TestWithDetection_SshConfig_Hostname(t *testing.T) { } handle := heartbeat.NewHandle(&sender, opts...) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(context.Background(), []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, - Entity: "ssh://user:pass@example.com:" + strconv.Itoa(port) + entity, + Entity: entity, EntityType: heartbeat.FileType, Time: 1585598060, UserAgent: "wakatime/13.0.7", @@ -133,23 +145,34 @@ func TestWithDetection_SshConfig_Hostname(t *testing.T) { } func TestWithDetection_SshConfig_UserKnownHostsFile_Mismatch(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is Windows.") - } - - logs := bytes.NewBuffer(nil) - - teardownLogCapture := captureLogs(logs) - defer teardownLogCapture() + tmpDir := t.TempDir() shutdown, host, port := testServer(t, true) defer shutdown() - tmpFile, err := os.CreateTemp(t.TempDir(), "") + tmpFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) defer tmpFile.Close() + logFile, err := os.CreateTemp(tmpDir, "") + require.NoError(t, err) + + defer logFile.Close() + + ctx := context.Background() + + v := viper.New() + v.Set("log-file", logFile.Name()) + v.Set("verbose", true) + + logger, err := cmd.SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) + ssh_config.DefaultUserSettings = &ssh_config.UserSettings{ IgnoreErrors: false, } @@ -167,10 +190,20 @@ func TestWithDetection_SshConfig_UserKnownHostsFile_Mismatch(t *testing.T) { err = os.WriteFile(tmpFile.Name(), []byte(fmt.Sprintf(string(template), host, knownHostsFile)), 0600) require.NoError(t, err) - entity, _ := filepath.Abs("./testdata/main.go") + entityFilepath, err := filepath.Abs("./testdata/main.go") + require.NoError(t, err) + + entity := "ssh://user:pass@github.com:" + strconv.Itoa(port) + + if runtime.GOOS == "windows" { + entityFilepath = windows.FormatFilePath(entityFilepath) + entity += "/" + entityFilepath + } else { + entity += entityFilepath + } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Empty(t, hh) return []heartbeat.Result{}, nil }, @@ -184,34 +217,32 @@ func TestWithDetection_SshConfig_UserKnownHostsFile_Mismatch(t *testing.T) { } handle := heartbeat.NewHandle(&sender, opts...) - results, err := handle([]heartbeat.Heartbeat{ + results, err := handle(ctx, []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, - Entity: "ssh://user:pass@github.com:" + strconv.Itoa(port) + entity, + Entity: entity, EntityType: heartbeat.FileType, Time: 1585598060, UserAgent: "wakatime/13.0.7", }, }) require.NoError(t, err) + assert.Empty(t, results) - assert.Contains(t, logs.String(), "ssh: handshake failed: ssh: host key mismatch") -} -func TestWithDetection_SshConfig_UserKnownHostsFile_Match(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is Windows.") - } + output, err := io.ReadAll(logFile) + require.NoError(t, err) - logs := bytes.NewBuffer(nil) + assert.Contains(t, string(output), "ssh: handshake failed: ssh: host key mismatch") +} - teardownLogCapture := captureLogs(logs) - defer teardownLogCapture() +func TestWithDetection_SshConfig_UserKnownHostsFile_Match(t *testing.T) { + tmpDir := t.TempDir() shutdown, host, port := testServer(t, true) defer shutdown() - tmpFile, err := os.CreateTemp(t.TempDir(), "") + tmpFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) defer tmpFile.Close() @@ -233,14 +264,24 @@ func TestWithDetection_SshConfig_UserKnownHostsFile_Match(t *testing.T) { err = os.WriteFile(tmpFile.Name(), []byte(fmt.Sprintf(string(template), host, knownHostsFile)), 0600) require.NoError(t, err) - entity, _ := filepath.Abs("./testdata/main.go") + entityFilepath, err := filepath.Abs("./testdata/main.go") + require.NoError(t, err) + + entity := "ssh://user:pass@example.com:" + strconv.Itoa(port) + + if runtime.GOOS == "windows" { + entityFilepath = windows.FormatFilePath(entityFilepath) + entity += "/" + entityFilepath + } else { + entity += entityFilepath + } sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Equal(t, []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, - Entity: "ssh://user:pass@example.com:" + strconv.Itoa(port) + entity, + Entity: entity, EntityType: heartbeat.FileType, LocalFile: hh[0].LocalFile, LocalFileNeedsCleanup: true, @@ -268,29 +309,21 @@ func TestWithDetection_SshConfig_UserKnownHostsFile_Match(t *testing.T) { } handle := heartbeat.NewHandle(&sender, opts...) - results, err := handle([]heartbeat.Heartbeat{ + results, err := handle(context.Background(), []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, - Entity: "ssh://user:pass@example.com:" + strconv.Itoa(port) + entity, + Entity: entity, EntityType: heartbeat.FileType, Time: 1585598060, UserAgent: "wakatime/13.0.7", }, }) require.NoError(t, err) + assert.Len(t, results, 1) } func TestWithDetection_Filtered(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping because OS is Windows.") - } - - logs := bytes.NewBuffer(nil) - - teardownLogCapture := captureLogs(logs) - defer teardownLogCapture() - shutdown, host, port := testServer(t, true) defer shutdown() @@ -319,7 +352,7 @@ func TestWithDetection_Filtered(t *testing.T) { entity, _ := filepath.Abs("./testdata/main.go") sender := mockSender{ - SendHeartbeatsFn: func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { assert.Empty(t, hh) return []heartbeat.Result{}, nil }, @@ -327,7 +360,7 @@ func TestWithDetection_Filtered(t *testing.T) { opts := []heartbeat.HandleOption{ filter.WithFiltering(filter.Config{ - Exclude: []regex.Regex{regexp.MustCompile(".*")}, + Exclude: []regex.Regex{regex.NewRegexpWrap(regexp.MustCompile(".*"))}, Include: nil, IncludeOnlyWithProjectFile: true, }), @@ -335,7 +368,7 @@ func TestWithDetection_Filtered(t *testing.T) { } handle := heartbeat.NewHandle(&sender, opts...) - results, err := handle([]heartbeat.Heartbeat{ + results, err := handle(context.Background(), []heartbeat.Heartbeat{ { Category: heartbeat.CodingCategory, Entity: "ssh://user:pass@example.com:" + strconv.Itoa(port) + entity, @@ -345,6 +378,7 @@ func TestWithDetection_Filtered(t *testing.T) { }, }) require.NoError(t, err) + assert.Empty(t, results) } @@ -354,14 +388,12 @@ func TestWithCleanup_NotTemporary(t *testing.T) { tmpFile.Close() - defer os.Remove(tmpFile.Name()) - opts := []heartbeat.HandleOption{ remote.WithCleanup(), } sender := mockSender{ - SendHeartbeatsFn: func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -375,7 +407,7 @@ func TestWithCleanup_NotTemporary(t *testing.T) { assert.FileExists(t, tmpFile.Name()) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(context.Background(), []heartbeat.Heartbeat{ { LocalFile: tmpFile.Name(), }, @@ -397,7 +429,7 @@ func TestWithCleanup(t *testing.T) { } sender := mockSender{ - SendHeartbeatsFn: func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -411,7 +443,7 @@ func TestWithCleanup(t *testing.T) { assert.FileExists(t, tmpFile.Name()) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(context.Background(), []heartbeat.Heartbeat{ { LocalFile: tmpFile.Name(), LocalFileNeedsCleanup: true, @@ -433,7 +465,7 @@ func TestWithCleanup_NotRemoteFile(t *testing.T) { } sender := mockSender{ - SendHeartbeatsFn: func(_ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { + SendHeartbeatsFn: func(_ context.Context, _ []heartbeat.Heartbeat) ([]heartbeat.Result, error) { return []heartbeat.Result{ { Status: 201, @@ -445,7 +477,7 @@ func TestWithCleanup_NotRemoteFile(t *testing.T) { handle := heartbeat.NewHandle(&sender, opts...) - _, err = handle([]heartbeat.Heartbeat{ + _, err = handle(context.Background(), []heartbeat.Heartbeat{ { LocalFile: tmpFile.Name(), }, @@ -456,13 +488,13 @@ func TestWithCleanup_NotRemoteFile(t *testing.T) { } type mockSender struct { - SendHeartbeatsFn func(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) + SendHeartbeatsFn func(context.Context, []heartbeat.Heartbeat) ([]heartbeat.Result, error) SendHeartbeatsFnInvoked bool } -func (m *mockSender) SendHeartbeats(hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { +func (m *mockSender) SendHeartbeats(ctx context.Context, hh []heartbeat.Heartbeat) ([]heartbeat.Result, error) { m.SendHeartbeatsFnInvoked = true - return m.SendHeartbeatsFn(hh) + return m.SendHeartbeatsFn(ctx, hh) } func keyAuth(_ ssh.ConnMetadata, _ ssh.PublicKey) (*ssh.Permissions, error) { @@ -782,16 +814,3 @@ func testServer(t *testing.T, expectError bool) (func(), string, int) { return func() { close(shutdown); listener.Close() }, host, port } - -func captureLogs(dest io.Writer) func() { - logOutput := log.Output() - - // will write to log output and dest - mw := io.MultiWriter(logOutput, dest) - - log.SetOutput(mw) - - return func() { - log.SetOutput(logOutput) - } -} diff --git a/pkg/system/system_linux.go b/pkg/system/system_linux.go index 5713f038..468d0350 100644 --- a/pkg/system/system_linux.go +++ b/pkg/system/system_linux.go @@ -3,6 +3,7 @@ package system import ( + "context" "fmt" "runtime" "strings" @@ -12,14 +13,16 @@ import ( ) // OSName returns the runtime machine's operating system name. -func OSName() string { +func OSName(ctx context.Context) string { os := runtime.GOOS var buf syscall.Utsname + logger := log.Extract(ctx) + err := syscall.Uname(&buf) if err != nil { - log.Debugf("Uname error: %s", err) + logger.Debugf("Uname error: %s", err) return os } diff --git a/pkg/system/system_other.go b/pkg/system/system_other.go index 5a41c3c2..42289e7f 100644 --- a/pkg/system/system_other.go +++ b/pkg/system/system_other.go @@ -3,10 +3,11 @@ package system import ( + "context" "runtime" ) // OSName returns the runtime machine's operating system name. -func OSName() string { +func OSName(_ context.Context) string { return runtime.GOOS } diff --git a/pkg/system/system_other_test.go b/pkg/system/system_other_test.go new file mode 100644 index 00000000..4a9f34be --- /dev/null +++ b/pkg/system/system_other_test.go @@ -0,0 +1,22 @@ +//go:build !linux + +package system_test + +import ( + "context" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/wakatime/wakatime-cli/pkg/system" +) + +func TestOSName(t *testing.T) { + if runtime.GOOS != "darwin" && runtime.GOOS != "windows" { + t.Skip("skipping test on non-darwin and non-windows system") + } + + name := system.OSName(context.Background()) + + assert.Equal(t, runtime.GOOS, name) +} From 9e52db756d334390c6457fa582e6dd091d64d2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Thu, 7 Nov 2024 10:21:18 -0300 Subject: [PATCH 2/6] Fix flakey tests --- cmd/heartbeat/heartbeat_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cmd/heartbeat/heartbeat_test.go b/cmd/heartbeat/heartbeat_test.go index e0832009..47f65e04 100644 --- a/cmd/heartbeat/heartbeat_test.go +++ b/cmd/heartbeat/heartbeat_test.go @@ -125,6 +125,8 @@ func TestSendHeartbeats(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) + defer offlineQueueFile.Close() + err = cmdheartbeat.SendHeartbeats(context.Background(), v, offlineQueueFile.Name()) require.NoError(t, err) @@ -150,12 +152,18 @@ func TestSendHeartbeats_RateLimited(t *testing.T) { tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime-config") require.NoError(t, err) + defer tmpFile.Close() + tmpFileInternal, err := os.CreateTemp(t.TempDir(), "wakatime-internal-config") require.NoError(t, err) + defer tmpFileInternal.Close() + offlineQueueFile, err := os.CreateTemp(t.TempDir(), "offline-queue-file") require.NoError(t, err) + defer offlineQueueFile.Close() + v := viper.New() v.SetDefault("sync-offline-activity", 1000) v.Set("api-url", testServerURL) @@ -216,6 +224,8 @@ func TestSendHeartbeats_WithFiltering_Exclude(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) + defer offlineQueueFile.Close() + err = cmdheartbeat.SendHeartbeats(context.Background(), v, offlineQueueFile.Name()) require.NoError(t, err) From 9a80ec24c91da0dc645836c1711a3d38dd030bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Thu, 7 Nov 2024 10:59:01 -0300 Subject: [PATCH 3/6] Fix flakey tests --- main_test.go | 28 +++++++++++++++++----------- pkg/offline/legacy.go | 5 +++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/main_test.go b/main_test.go index f491d36d..3127f8d6 100644 --- a/main_test.go +++ b/main_test.go @@ -116,7 +116,8 @@ func testSendHeartbeats(t *testing.T, projectFolder, entity, p string) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already closed" error + offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") require.NoError(t, err) @@ -215,7 +216,8 @@ func TestSendHeartbeats_SecondaryApiKey(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already closed" error + offlineQueueFileLegacy.Close() tmpInternalConfigFile, err := os.CreateTemp(tmpDir, "wakatime-internal.cfg") require.NoError(t, err) @@ -293,7 +295,8 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already exists" error on Windows + offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") require.NoError(t, err) @@ -385,7 +388,7 @@ func TestSendHeartbeats_ExtraHeartbeats_SyncLegacyOfflineActivity(t *testing.T) offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "legacy-offline-file") require.NoError(t, err) - // early close to avoid file locking in Windows + // close the file to avoid "file already exists" error on Windows offlineQueueFileLegacy.Close() db, err := bolt.Open(offlineQueueFileLegacy.Name(), 0600, nil) @@ -438,8 +441,6 @@ func TestSendHeartbeats_ExtraHeartbeats_SyncLegacyOfflineActivity(t *testing.T) buffer := bytes.NewBuffer(data) - assert.FileExists(t, offlineQueueFileLegacy.Name()) - runWakatimeCli( t, buffer, @@ -529,7 +530,8 @@ func TestSendHeartbeats_Err(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already exists" error on Windows + offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") require.NoError(t, err) @@ -590,7 +592,8 @@ func TestSendHeartbeats_ErrAuth_InvalidAPIKEY(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already exists" error on Windows + offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") require.NoError(t, err) @@ -650,7 +653,8 @@ func TestSendHeartbeats_MalformedConfig(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already exists" error on Windows + offlineQueueFileLegacy.Close() out := runWakatimeCliExpectErr( t, @@ -684,7 +688,8 @@ func TestSendHeartbeats_MalformedInternalConfig(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already exists" error on Windows + offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") require.NoError(t, err) @@ -1017,7 +1022,8 @@ func TestPrintOfflineHeartbeats(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFileLegacy.Close() + // close the file to avoid "file already exists" error on Windows + offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") require.NoError(t, err) diff --git a/pkg/offline/legacy.go b/pkg/offline/legacy.go index 7320b02c..17f61542 100644 --- a/pkg/offline/legacy.go +++ b/pkg/offline/legacy.go @@ -5,10 +5,11 @@ import ( "fmt" "path/filepath" - "github.com/mitchellh/go-homedir" - "github.com/spf13/viper" "github.com/wakatime/wakatime-cli/pkg/ini" "github.com/wakatime/wakatime-cli/pkg/vipertools" + + "github.com/mitchellh/go-homedir" + "github.com/spf13/viper" ) // dbLegacyFilename is the legacy bolt db filename. From 490ad2c5060d9f0d3638c2c9126e88fcfde07a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Fri, 8 Nov 2024 18:06:39 -0300 Subject: [PATCH 4/6] Replace logrus by uber-go/zap --- cmd/fileexperts/fileexperts_test.go | 3 - cmd/run.go | 18 ++-- cmd/run_internal_test.go | 6 +- go.mod | 2 + go.sum | 6 ++ main_test.go | 18 ++-- pkg/log/context.go | 9 +- pkg/log/log.go | 155 +++++++++++++--------------- pkg/log/log_test.go | 12 +-- pkg/log/option.go | 25 +++++ pkg/log/writer.go | 41 ++++++++ pkg/offline/offline.go | 4 +- pkg/project/project.go | 6 +- pkg/remote/remote.go | 4 +- 14 files changed, 190 insertions(+), 119 deletions(-) create mode 100644 pkg/log/option.go create mode 100644 pkg/log/writer.go diff --git a/cmd/fileexperts/fileexperts_test.go b/cmd/fileexperts/fileexperts_test.go index 423e4670..13c6d300 100644 --- a/cmd/fileexperts/fileexperts_test.go +++ b/cmd/fileexperts/fileexperts_test.go @@ -121,9 +121,6 @@ func TestFileExperts_NonExistingEntity(t *testing.T) { _, err = fileexperts.FileExperts(ctx, v) require.NoError(t, err) - err = logFile.Sync() - require.NoError(t, err) - output, err := io.ReadAll(logFile) require.NoError(t, err) diff --git a/cmd/run.go b/cmd/run.go index 15fde8c2..d191316f 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -242,9 +242,7 @@ func SetupLogging(ctx context.Context, v *viper.Viper) (*log.Logger, error) { return nil, fmt.Errorf("failed to load log params: %s", err) } - logger := log.New(params.Verbose, params.SendDiagsOnErrors, params.Metrics) - - logFile := os.Stdout + destOutput := os.Stdout if !params.ToStdout { dir := filepath.Dir(params.File) @@ -255,16 +253,20 @@ func SetupLogging(ctx context.Context, v *viper.Viper) (*log.Logger, error) { } } - logFile, err = os.OpenFile(params.File, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) // nolint:gosec + destOutput, err = os.OpenFile(params.File, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) // nolint:gosec if err != nil { return nil, fmt.Errorf("error opening log file: %s", err) } - - logger.SetOutput(logFile) } - logger.SetVerbose(params.Verbose) - log.SetJww(params.Verbose, logFile) + logger := log.New( + destOutput, + log.WithVerbose(params.Verbose), + log.WithSendDiagsOnErrors(params.SendDiagsOnErrors), + log.WithMetrics(params.Metrics), + ) + + log.SetJww(params.Verbose, destOutput) return logger, nil } diff --git a/cmd/run_internal_test.go b/cmd/run_internal_test.go index 7ead4ebf..28885ef0 100644 --- a/cmd/run_internal_test.go +++ b/cmd/run_internal_test.go @@ -135,7 +135,8 @@ func TestRunCmd_BackoffLoggedWithVerbose(t *testing.T) { offlineQueueFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFile.Close() + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows + offlineQueueFile.Close() entity, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) @@ -193,7 +194,8 @@ func TestRunCmd_BackoffNotLogged(t *testing.T) { offlineQueueFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer offlineQueueFile.Close() + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows + offlineQueueFile.Close() logFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) diff --git a/go.mod b/go.mod index 9ce2e3a0..56a47ea8 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.4 go.etcd.io/bbolt v1.3.8 + go.uber.org/zap v1.27.0 golang.org/x/crypto v0.24.0 golang.org/x/net v0.26.0 golang.org/x/text v0.16.0 @@ -51,6 +52,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/yookoala/realpath v1.0.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/sys v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9a3d9cb2..56c816d9 100644 --- a/go.sum +++ b/go.sum @@ -285,6 +285,12 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20180214000028-650f4a345ab4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/main_test.go b/main_test.go index 3127f8d6..acedcaaf 100644 --- a/main_test.go +++ b/main_test.go @@ -116,7 +116,7 @@ func testSendHeartbeats(t *testing.T, projectFolder, entity, p string) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already closed" error + // close the file to avoid "The process cannot access the file because it is being used by another process" error offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") @@ -216,7 +216,7 @@ func TestSendHeartbeats_SecondaryApiKey(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already closed" error + // close the file to avoid "The process cannot access the file because it is being used by another process" error offlineQueueFileLegacy.Close() tmpInternalConfigFile, err := os.CreateTemp(tmpDir, "wakatime-internal.cfg") @@ -295,7 +295,7 @@ func TestSendHeartbeats_ExtraHeartbeats(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") @@ -388,7 +388,7 @@ func TestSendHeartbeats_ExtraHeartbeats_SyncLegacyOfflineActivity(t *testing.T) offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "legacy-offline-file") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() db, err := bolt.Open(offlineQueueFileLegacy.Name(), 0600, nil) @@ -530,7 +530,7 @@ func TestSendHeartbeats_Err(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") @@ -592,7 +592,7 @@ func TestSendHeartbeats_ErrAuth_InvalidAPIKEY(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") @@ -653,7 +653,7 @@ func TestSendHeartbeats_MalformedConfig(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() out := runWakatimeCliExpectErr( @@ -688,7 +688,7 @@ func TestSendHeartbeats_MalformedInternalConfig(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") @@ -1022,7 +1022,7 @@ func TestPrintOfflineHeartbeats(t *testing.T) { offlineQueueFileLegacy, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - // close the file to avoid "file already exists" error on Windows + // close to avoid "The process cannot access the file because it is being used by another process" error on Windows offlineQueueFileLegacy.Close() tmpConfigFile, err := os.CreateTemp(tmpDir, "wakatime.cfg") diff --git a/pkg/log/context.go b/pkg/log/context.go index 55e18c6f..c940ca9f 100644 --- a/pkg/log/context.go +++ b/pkg/log/context.go @@ -1,6 +1,9 @@ package log -import "context" +import ( + "context" + "os" +) type ( ctxMarker struct{} @@ -17,7 +20,9 @@ var ctxMarkerKey = &ctxMarker{} func Extract(ctx context.Context) *Logger { l, ok := ctx.Value(ctxMarkerKey).(*ctxLogger) if !ok || l == nil { - return New(false, false, false) + // TODO: It should never happen but if it does, + // we should find a way to create a new logger using passed params during initialization + return New(os.Stdout) } return l.logger diff --git a/pkg/log/log.go b/pkg/log/log.go index e625fa04..66c45009 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -3,77 +3,65 @@ package log import ( "fmt" "io" - "os" - "runtime" - "strings" - "github.com/wakatime/wakatime-cli/pkg/version" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" - "github.com/sirupsen/logrus" jww "github.com/spf13/jwalterweatherman" + "github.com/wakatime/wakatime-cli/pkg/version" ) // Logger is the log entry. type Logger struct { - entry *logrus.Entry - metrics bool - sendDiagsOnErrors bool - verbose bool -} + entry *zap.Logger + atomicLevel zap.AtomicLevel + currentOutput io.Writer + dynamicWriteSyncer *DynamicWriteSyncer + metrics bool + sendDiagsOnErrors bool + verbose bool +} + +// New creates a new Logger that writes to dest. +func New(dest io.Writer, opts ...Option) *Logger { + atom := zap.NewAtomicLevel() + dynamicWriteSyncer := NewDynamicWriteSyncer(zapcore.AddSync(dest)) + + encoderCfg := zap.NewProductionEncoderConfig() + encoderCfg.TimeKey = "now" + encoderCfg.EncodeTime = zapcore.RFC3339TimeEncoder + encoderCfg.MessageKey = "message" + encoderCfg.FunctionKey = "func" + + l := zap.New(zapcore.NewCore( + zapcore.NewJSONEncoder(encoderCfg), + dynamicWriteSyncer, + atom, + ), + zap.AddCaller(), + zap.AddCallerSkip(1), + zap.AddStacktrace(zap.FatalLevel), + ) + + l = l.With( + zap.String("version", version.Version), + zap.String("os/arch", fmt.Sprintf("%s/%s", version.OS, version.Arch)), + ) -// New creates a new Logger. -func New(verbose, sendDiagsOnErrors, metrics bool) *Logger { logger := &Logger{ - entry: new(), - metrics: metrics, - sendDiagsOnErrors: sendDiagsOnErrors, - verbose: verbose, + entry: l, + atomicLevel: atom, + currentOutput: dest, + dynamicWriteSyncer: dynamicWriteSyncer, } - logger.SetVerbose(verbose) + for _, option := range opts { + option(logger) + } return logger } -func new() *logrus.Entry { - entry := logrus.NewEntry(&logrus.Logger{ - Out: os.Stdout, - Formatter: &logrus.JSONFormatter{ - FieldMap: logrus.FieldMap{ - logrus.FieldKeyTime: "now", - logrus.FieldKeyFile: "caller", - logrus.FieldKeyMsg: "message", - }, - DisableHTMLEscape: true, - CallerPrettyfier: func(f *runtime.Frame) (string, string) { - // Simplifies function description by removing dangling func name from it. - lastSlash := strings.LastIndexByte(f.Function, '/') - if lastSlash < 0 { - lastSlash = 0 - } - parts := strings.Split(f.Function[lastSlash+1:], ".") - - // Simplifies file path by removing base path from it. - lastPath := strings.LastIndex(f.File, "wakatime-cli/") - if lastPath < 0 { - lastPath = 0 - } - file := f.File[lastPath+13:] - - return fmt.Sprintf("%s.%s", parts[0], parts[1]), - fmt.Sprintf("%s:%d", file, f.Line) - }, - }, - Level: logrus.InfoLevel, - ExitFunc: os.Exit, - ReportCaller: true, - }) - entry.Data["version"] = version.Version - entry.Data["os/arch"] = fmt.Sprintf("%s/%s", version.OS, version.Arch) - - return entry -} - // IsMetricsEnabled returns true if it should collect metrics. func (l *Logger) IsMetricsEnabled() bool { return l.metrics @@ -86,7 +74,7 @@ func (l *Logger) IsVerboseEnabled() bool { // Output returns the current log output. func (l *Logger) Output() io.Writer { - return l.entry.Logger.Out + return l.currentOutput } // SendDiagsOnErrors returns true if diagnostics should be sent on errors. @@ -96,27 +84,30 @@ func (l *Logger) SendDiagsOnErrors() bool { // SetOutput defines sets the log output to io.Writer. func (l *Logger) SetOutput(w io.Writer) { - l.entry.Logger.Out = w + l.currentOutput = w + l.dynamicWriteSyncer.SetWriter(zapcore.AddSync(w)) } // SetVerbose sets log level to debug if enabled. func (l *Logger) SetVerbose(verbose bool) { + l.verbose = verbose + if verbose { - l.entry.Logger.SetLevel(logrus.DebugLevel) + l.atomicLevel.SetLevel(zap.DebugLevel) } else { - l.entry.Logger.SetLevel(logrus.InfoLevel) + l.atomicLevel.SetLevel(zap.InfoLevel) } } // Flush flushes the log output and closes the file. func (l *Logger) Flush() { - if file, ok := l.entry.Logger.Out.(*os.File); ok { - if err := file.Sync(); err != nil { - l.entry.Debugf("failed to flush log file: %s", err) - } + if err := l.entry.Sync(); err != nil { + l.Debugf("failed to flush log file: %s", err) + } - if err := file.Close(); err != nil { - l.entry.Debugf("failed to close log file: %s", err) + if closer, ok := l.currentOutput.(io.Closer); ok { + if err := closer.Close(); err != nil { + l.Debugf("failed to close log file: %s", err) } } } @@ -134,55 +125,55 @@ func SetJww(verbose bool, w io.Writer) { // Debugf logs a message at level Debug. func (l *Logger) Debugf(format string, args ...any) { - l.entry.Debugf(format, args...) + l.entry.Log(zapcore.DebugLevel, fmt.Sprintf(format, args...)) } // Infof logs a message at level Info. func (l *Logger) Infof(format string, args ...any) { - l.entry.Infof(format, args...) + l.entry.Log(zapcore.InfoLevel, fmt.Sprintf(format, args...)) } // Warnf logs a message at level Warn. func (l *Logger) Warnf(format string, args ...any) { - l.entry.Warnf(format, args...) + l.entry.Log(zapcore.WarnLevel, fmt.Sprintf(format, args...)) } // Errorf logs a message at level Error. func (l *Logger) Errorf(format string, args ...any) { - l.entry.Errorf(format, args...) + l.entry.Log(zapcore.ErrorLevel, fmt.Sprintf(format, args...)) } // Fatalf logs a message at level Fatal then the process will exit with status set to 1. func (l *Logger) Fatalf(format string, args ...any) { - l.entry.Fatalf(format, args...) + l.entry.Log(zapcore.FatalLevel, fmt.Sprintf(format, args...)) } // Debugln logs a message at level Debug. -func (l *Logger) Debugln(args ...any) { - l.entry.Debugln(args...) +func (l *Logger) Debugln(msg string) { + l.entry.Log(zapcore.DebugLevel, msg) } // Infoln logs a message at level Info. -func (l *Logger) Infoln(args ...any) { - l.entry.Infoln(args...) +func (l *Logger) Infoln(msg string) { + l.entry.Log(zapcore.InfoLevel, msg) } // Warnln logs a message at level Warn. -func (l *Logger) Warnln(args ...any) { - l.entry.Warnln(args...) +func (l *Logger) Warnln(msg string) { + l.entry.Log(zapcore.WarnLevel, msg) } // Errorln logs a message at level Error. -func (l *Logger) Errorln(args ...any) { - l.entry.Errorln(args...) +func (l *Logger) Errorln(msg string) { + l.entry.Log(zapcore.ErrorLevel, msg) } // Fatalln logs a message at level Fatal then the process will exit with status set to 1. -func (l *Logger) Fatalln(args ...any) { - l.entry.Fatalln(args...) +func (l *Logger) Fatalln(msg string) { + l.entry.Log(zapcore.FatalLevel, msg) } // WithField adds a single field to the Logger. func (l *Logger) WithField(key string, value any) { - l.entry.Data[key] = value + l.entry = l.entry.With(zap.Any(key, value)) } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go index 1390ea26..b095c0ff 100644 --- a/pkg/log/log_test.go +++ b/pkg/log/log_test.go @@ -9,37 +9,37 @@ import ( ) func TestLog_IsMetricsEnabled(t *testing.T) { - logger := log.New(false, false, true) + logger := log.New(nil, log.WithMetrics(true)) assert.True(t, logger.IsMetricsEnabled()) } func TestLog_IsMetricsEnabled_Disabled(t *testing.T) { - logger := log.New(false, false, false) + logger := log.New(nil) assert.False(t, logger.IsMetricsEnabled()) } func TestLog_IsVerboseEnabled(t *testing.T) { - logger := log.New(true, false, false) + logger := log.New(nil, log.WithVerbose(true)) assert.True(t, logger.IsVerboseEnabled()) } func TestLog_IsVerboseEnabled_Disabled(t *testing.T) { - logger := log.New(false, false, false) + logger := log.New(nil) assert.False(t, logger.IsVerboseEnabled()) } func TestLog_SendDiagsOnErrors(t *testing.T) { - logger := log.New(false, true, false) + logger := log.New(nil, log.WithSendDiagsOnErrors(true)) assert.True(t, logger.SendDiagsOnErrors()) } func TestLog_SendDiagsOnErrors_Disabled(t *testing.T) { - logger := log.New(false, false, false) + logger := log.New(nil) assert.False(t, logger.SendDiagsOnErrors()) } diff --git a/pkg/log/option.go b/pkg/log/option.go new file mode 100644 index 00000000..0c547b08 --- /dev/null +++ b/pkg/log/option.go @@ -0,0 +1,25 @@ +package log + +// Option is a functional option for Logger. +type Option func(*Logger) + +// WithVerbose sets verbose mode. +func WithVerbose(verbose bool) Option { + return func(l *Logger) { + l.SetVerbose(verbose) + } +} + +// WithMetrics sets metrics mode. +func WithMetrics(metrics bool) Option { + return func(l *Logger) { + l.metrics = metrics + } +} + +// WithSendDiagsOnErrors sets send diagnostics on errors mode. +func WithSendDiagsOnErrors(sendDiagsOnErrors bool) Option { + return func(l *Logger) { + l.sendDiagsOnErrors = sendDiagsOnErrors + } +} diff --git a/pkg/log/writer.go b/pkg/log/writer.go new file mode 100644 index 00000000..0a299836 --- /dev/null +++ b/pkg/log/writer.go @@ -0,0 +1,41 @@ +package log + +import ( + "sync" + + "go.uber.org/zap/zapcore" +) + +// DynamicWriteSyncer allows changing the underlying WriteSyncer at runtime. +type DynamicWriteSyncer struct { + mu sync.RWMutex + writer zapcore.WriteSyncer +} + +// Write writes the log entry to the current writer. +func (d *DynamicWriteSyncer) Write(p []byte) (n int, err error) { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.writer.Write(p) +} + +// Sync calls Sync on the current writer. +func (d *DynamicWriteSyncer) Sync() error { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.writer.Sync() +} + +// SetWriter allows updating the underlying writer at runtime. +func (d *DynamicWriteSyncer) SetWriter(newWriter zapcore.WriteSyncer) { + d.mu.Lock() + defer d.mu.Unlock() + d.writer = newWriter +} + +// NewDynamicWriteSyncer initializes the dynamic writer with an initial writer. +func NewDynamicWriteSyncer(initial zapcore.WriteSyncer) *DynamicWriteSyncer { + return &DynamicWriteSyncer{writer: initial} +} diff --git a/pkg/offline/offline.go b/pkg/offline/offline.go index ffd93c4a..11019f9e 100644 --- a/pkg/offline/offline.go +++ b/pkg/offline/offline.go @@ -437,9 +437,9 @@ func openDB(ctx context.Context, filepath string) (db *bolt.DB, _ func(), err er return nil, nil, fmt.Errorf("failed to open db file: %s", err) } - logger := log.Extract(ctx) - return db, func() { + logger := log.Extract(ctx) + // recover from panic when closing db defer func() { if r := recover(); r != nil { diff --git a/pkg/project/project.go b/pkg/project/project.go index 5967ead5..621a757c 100644 --- a/pkg/project/project.go +++ b/pkg/project/project.go @@ -143,7 +143,7 @@ func WithDetection(config Config) heartbeat.HandleOption { logger := log.Extract(ctx) for n, h := range hh { - logger.Debugln("execute project detection for:", h.Entity) + logger.Debugf("execute project detection for: %s", h.Entity) // first, use .wakatime-project or [projectmap] section with entity path. // Then, detect with project folder. This tries to use the same project name @@ -246,7 +246,7 @@ func Detect(ctx context.Context, patterns []MapPattern, args ...DetecterArg) (Re } for _, p := range configPlugins { - logger.Debugln("execute", p.ID().String()) + logger.Debugf("execute %s", p.ID().String()) result, detected, err := p.Detect(ctx) if err != nil { @@ -296,7 +296,7 @@ func DetectWithRevControl( } for _, p := range revControlPlugins { - logger.Debugln("execute", p.ID().String()) + logger.Debugf("execute %s", p.ID().String()) result, detected, err := p.Detect(ctx) if err != nil { diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 044e489e..30abb534 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -111,7 +111,7 @@ func WithCleanup() heartbeat.HandleOption { for _, h := range hh { if h.LocalFileNeedsCleanup { - logger.Debugln("deleting temporary file:", h.LocalFile) + logger.Debugf("deleting temporary file: %s", h.LocalFile) deleteLocalFile(ctx, h.LocalFile) } @@ -310,7 +310,7 @@ func (c Client) knownHostKeys(ctx context.Context) []ssh.PublicKey { return nil }(filename); err != nil { - logger.Debugln(err) + logger.Debugln(err.Error()) } } From 265065a3877633dd8603742472aadf174c400a86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Sat, 9 Nov 2024 08:59:17 -0300 Subject: [PATCH 5/6] Add missing tests for cmd/offline --- cmd/heartbeat/heartbeat_test.go | 12 +- cmd/offline/offline_test.go | 113 +++++++++++++++ cmd/offline/testdata/extra_heartbeats.json | 1 + cmd/offline/testdata/localfile.go | 3 + cmd/offline/testdata/main.go | 11 ++ cmd/offline/testdata/main.py | 20 +++ cmd/run_internal_test.go | 160 +++++++++++++++++++++ 7 files changed, 317 insertions(+), 3 deletions(-) create mode 100644 cmd/offline/offline_test.go create mode 100644 cmd/offline/testdata/extra_heartbeats.json create mode 100644 cmd/offline/testdata/localfile.go create mode 100644 cmd/offline/testdata/main.go create mode 100644 cmd/offline/testdata/main.py diff --git a/cmd/heartbeat/heartbeat_test.go b/cmd/heartbeat/heartbeat_test.go index 47f65e04..2b40ff92 100644 --- a/cmd/heartbeat/heartbeat_test.go +++ b/cmd/heartbeat/heartbeat_test.go @@ -52,6 +52,8 @@ func TestSendHeartbeats(t *testing.T) { tmpFile, err := os.CreateTemp(t.TempDir(), "wakatime-config") require.NoError(t, err) + defer tmpFile.Close() + subfolders := project.CountSlashesInProjectFolder(projectFolder) router.HandleFunc("/users/current/heartbeats.bulk", func(w http.ResponseWriter, req *http.Request) { @@ -544,12 +546,12 @@ func TestSendHeartbeats_NonExistingEntity(t *testing.T) { ctx = log.ToContext(ctx, logger) - f, err := os.CreateTemp(tmpDir, "") + offlineQueueFile, err := os.CreateTemp(tmpDir, "") require.NoError(t, err) - defer f.Close() + defer offlineQueueFile.Close() - err = cmdheartbeat.SendHeartbeats(ctx, v, f.Name()) + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) output, err := io.ReadAll(logFile) @@ -1061,6 +1063,8 @@ func TestSendHeartbeats_ObfuscateProject(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) + defer offlineQueueFile.Close() + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) @@ -1150,6 +1154,8 @@ func TestSendHeartbeats_ObfuscateProjectNotBranch(t *testing.T) { offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") require.NoError(t, err) + defer offlineQueueFile.Close() + err = cmdheartbeat.SendHeartbeats(ctx, v, offlineQueueFile.Name()) require.NoError(t, err) diff --git a/cmd/offline/offline_test.go b/cmd/offline/offline_test.go new file mode 100644 index 00000000..b308760a --- /dev/null +++ b/cmd/offline/offline_test.go @@ -0,0 +1,113 @@ +package offline_test + +import ( + "context" + "encoding/json" + "os" + "testing" + + cmdoffline "github.com/wakatime/wakatime-cli/cmd/offline" + "github.com/wakatime/wakatime-cli/pkg/heartbeat" + "github.com/wakatime/wakatime-cli/pkg/offline" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSaveHeartbeats(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer tmpFile.Close() + + offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer offlineQueueFile.Close() + + ctx := context.Background() + + v := viper.New() + v.Set("config", tmpFile.Name()) + v.Set("category", "debugging") + v.Set("cursorpos", 42) + v.Set("entity", "testdata/main.go") + v.Set("entity-type", "file") + v.Set("key", "00000000-0000-4000-8000-000000000000") + v.Set("language", "Go") + v.Set("alternate-language", "Golang") + v.Set("hide-branch-names", true) + v.Set("project", "wakatime-cli") + v.Set("lineno", 13) + v.Set("time", 1585598059.1) + v.Set("timeout", 5) + v.Set("write", true) + + err = cmdoffline.SaveHeartbeats(ctx, v, nil, offlineQueueFile.Name()) + require.NoError(t, err) + + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) + require.NoError(t, err) + + assert.Equal(t, 1, offlineCount) +} + +func TestSaveHeartbeats_ExtraHeartbeats(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer tmpFile.Close() + + offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer offlineQueueFile.Close() + + ctx := context.Background() + + data, err := os.ReadFile("testdata/extra_heartbeats.json") + require.NoError(t, err) + + var hh []heartbeat.Heartbeat + + err = json.Unmarshal(data, &hh) + require.NoError(t, err) + + v := viper.New() + v.Set("config", tmpFile.Name()) + v.Set("entity", "testdata/main.go") + v.Set("key", "00000000-0000-4000-8000-000000000000") + + err = cmdoffline.SaveHeartbeats(ctx, v, hh, offlineQueueFile.Name()) + require.NoError(t, err) + + offlineCount, err := offline.CountHeartbeats(ctx, offlineQueueFile.Name()) + require.NoError(t, err) + + assert.Equal(t, 25, offlineCount) +} + +func TestSaveHeartbeats_OfflineDisabled(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer tmpFile.Close() + + offlineQueueFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer offlineQueueFile.Close() + + ctx := context.Background() + + v := viper.New() + v.Set("config", tmpFile.Name()) + v.Set("disable-offline", true) + v.Set("entity", "testdata/main.go") + v.Set("key", "00000000-0000-4000-8000-000000000000") + + err = cmdoffline.SaveHeartbeats(ctx, v, nil, offlineQueueFile.Name()) + + assert.EqualError(t, err, "saving to offline db disabled") +} diff --git a/cmd/offline/testdata/extra_heartbeats.json b/cmd/offline/testdata/extra_heartbeats.json new file mode 100644 index 00000000..6e043da1 --- /dev/null +++ b/cmd/offline/testdata/extra_heartbeats.json @@ -0,0 +1 @@ +[ { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598059 }, { "alternate_language": "Py", "category": "debugging", "cursorpos": null, "entity": "testdata/main.py", "is_write": null, "language": "Python", "lineno": null, "lines": null, "project": "wakatime-cli", "type": "file", "timestamp": 1585598060 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598061 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598062 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598063 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598064 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598065 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598066 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598067 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598068 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598069 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598070 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598071 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598072 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598073 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598074 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598075 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598076 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598077 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598078 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598079 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598080 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598081 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598082 }, { "alternate_language": "Golang", "alternate_project": "billing", "category": "coding", "cursorpos": 12, "entity": "testdata/main.go", "entity_type": "file", "is_write": true, "language": "Go", "lineno": 42, "lines": 45, "project": "wakatime-cli", "time": 1585598083 } ] \ No newline at end of file diff --git a/cmd/offline/testdata/localfile.go b/cmd/offline/testdata/localfile.go new file mode 100644 index 00000000..5626cbe7 --- /dev/null +++ b/cmd/offline/testdata/localfile.go @@ -0,0 +1,3 @@ +hello +hello +world diff --git a/cmd/offline/testdata/main.go b/cmd/offline/testdata/main.go new file mode 100644 index 00000000..c6e84062 --- /dev/null +++ b/cmd/offline/testdata/main.go @@ -0,0 +1,11 @@ +package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("hello world") + os.Exit(0) +} diff --git a/cmd/offline/testdata/main.py b/cmd/offline/testdata/main.py new file mode 100644 index 00000000..56ff0925 --- /dev/null +++ b/cmd/offline/testdata/main.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# vim: set filetype=python + +from __future__ import print_statement + +import os, sys +from flask import session, render_template +import simplejson as json + + +class MyClass(object): + """this class + """ + + def method1(self): + a = 1 + 2 + b = 'hello world!' + for x in y: + print(x) + raise Exception() diff --git a/cmd/run_internal_test.go b/cmd/run_internal_test.go index 28885ef0..728ae750 100644 --- a/cmd/run_internal_test.go +++ b/cmd/run_internal_test.go @@ -46,6 +46,166 @@ func TestRunCmd_Err(t *testing.T) { assert.Equal(t, exitcode.ErrGeneric, err.(exitcode.Err).Code) } +func TestRunCmd_Panic(t *testing.T) { + testServerURL, router, tearDown := setupTestServer() + defer tearDown() + + version.OS = "some os" + version.Arch = "some architecture" + version.Version = "some version" + + router.HandleFunc("/plugins/errors", func(w http.ResponseWriter, req *http.Request) { + // check request + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) + assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) + + expectedBodyTpl, err := os.ReadFile("testdata/diagnostics_request_template.json") + require.NoError(t, err) + + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + + var diagnostics struct { + Architecture string `json:"architecture"` + CliVersion string `json:"cli_version"` + Editor string `json:"editor"` + Logs string `json:"logs"` + OriginalError string `json:"error_message"` + Platform string `json:"platform"` + Plugin string `json:"plugin"` + Stack string `json:"stacktrace"` + } + + err = json.Unmarshal(body, &diagnostics) + require.NoError(t, err) + + expectedBodyStr := fmt.Sprintf( + string(expectedBodyTpl), + jsonEscape(t, diagnostics.OriginalError), + jsonEscape(t, diagnostics.Logs), + jsonEscape(t, diagnostics.Stack), + ) + + assert.JSONEq(t, expectedBodyStr, string(body)) + + // send response + w.WriteHeader(http.StatusCreated) + }) + + logFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer logFile.Close() + + ctx := context.Background() + + v := viper.New() + v.Set("api-url", testServerURL) + v.Set("log-file", logFile.Name()) + + logger, err := SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) + + err = runCmd(ctx, v, false, false, func(_ context.Context, _ *viper.Viper) (int, error) { + panic("fail") + }) + + var errexitcode exitcode.Err + + require.ErrorAs(t, err, &errexitcode) + assert.Equal(t, exitcode.ErrGeneric, err.(exitcode.Err).Code) + + output, err := io.ReadAll(logFile) + require.NoError(t, err) + + assert.Contains(t, string(output), "panicked") +} + +func TestRunCmd_Panic_Verbose(t *testing.T) { + testServerURL, router, tearDown := setupTestServer() + defer tearDown() + + version.OS = "some os" + version.Arch = "some architecture" + version.Version = "some version" + + router.HandleFunc("/plugins/errors", func(w http.ResponseWriter, req *http.Request) { + // check request + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, []string{"Basic MDAwMDAwMDAtMDAwMC00MDAwLTgwMDAtMDAwMDAwMDAwMDAw"}, req.Header["Authorization"]) + assert.Equal(t, []string{"application/json"}, req.Header["Content-Type"]) + + expectedBodyTpl, err := os.ReadFile("testdata/diagnostics_request_template.json") + require.NoError(t, err) + + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + + var diagnostics struct { + Architecture string `json:"architecture"` + CliVersion string `json:"cli_version"` + Editor string `json:"editor"` + Logs string `json:"logs"` + OriginalError string `json:"error_message"` + Platform string `json:"platform"` + Plugin string `json:"plugin"` + Stack string `json:"stacktrace"` + } + + err = json.Unmarshal(body, &diagnostics) + require.NoError(t, err) + + expectedBodyStr := fmt.Sprintf( + string(expectedBodyTpl), + jsonEscape(t, diagnostics.OriginalError), + jsonEscape(t, diagnostics.Logs), + jsonEscape(t, diagnostics.Stack), + ) + + assert.JSONEq(t, expectedBodyStr, string(body)) + + // send response + w.WriteHeader(http.StatusCreated) + }) + + logFile, err := os.CreateTemp(t.TempDir(), "") + require.NoError(t, err) + + defer logFile.Close() + + ctx := context.Background() + + v := viper.New() + v.Set("api-url", testServerURL) + v.Set("log-file", logFile.Name()) + + logger, err := SetupLogging(ctx, v) + require.NoError(t, err) + + defer logger.Flush() + + ctx = log.ToContext(ctx, logger) + + err = runCmd(ctx, v, true, false, func(_ context.Context, _ *viper.Viper) (int, error) { + panic("fail") + }) + + var errexitcode exitcode.Err + + require.ErrorAs(t, err, &errexitcode) + assert.Equal(t, exitcode.ErrGeneric, err.(exitcode.Err).Code) + + output, err := io.ReadAll(logFile) + require.NoError(t, err) + + assert.Contains(t, string(output), "panicked") +} + func TestRunCmd_ErrOfflineEnqueue(t *testing.T) { testServerURL, router, tearDown := setupTestServer() defer tearDown() From b74e98790f97e8e4f71fced6197c9bd7d310b1ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Sat, 9 Nov 2024 11:31:09 -0300 Subject: [PATCH 6/6] Bump pipeline dependencies to latest version --- .github/workflows/on_push.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index ef15fd91..b755dc4e 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -34,7 +34,7 @@ jobs: run: make test - name: Linter - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest skip-cache: true @@ -94,7 +94,7 @@ jobs: run: make test - name: Linter - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest skip-cache: true @@ -145,7 +145,7 @@ jobs: run: make test - name: Linter - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest skip-cache: true @@ -574,7 +574,7 @@ jobs: path: build/ - name: Import Code-Signing Certificates - uses: Apple-Actions/import-codesign-certs@v1 + uses: Apple-Actions/import-codesign-certs@v3 with: # The certificates in a PKCS12 file encoded as a base64 string p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} @@ -644,7 +644,7 @@ jobs: # Run only for develop branch if: ${{ github.ref == 'refs/heads/develop' }} name: Changelog for develop - uses: gandarez/changelog-action@v1.2.0 + uses: gandarez/changelog-action@v1.4.0 id: changelog-develop with: current_tag: ${{ github.sha }} @@ -655,7 +655,7 @@ jobs: # Run only for release branch if: ${{ github.ref == 'refs/heads/release' }} name: Get related pull request - uses: 8BitJonny/gh-get-current-pr@v2.2.0 + uses: 8BitJonny/gh-get-current-pr@3.0.0 id: changelog-release with: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -681,7 +681,7 @@ jobs: run: ./bin/prepare_assets.sh - name: "Create release" - uses: softprops/action-gh-release@master + uses: softprops/action-gh-release@v2 with: name: ${{ needs.version.outputs.semver_tag }} tag_name: ${{ needs.version.outputs.semver_tag }}