diff --git a/metadata.go b/metadata.go index 6b2cd81..93b13ec 100644 --- a/metadata.go +++ b/metadata.go @@ -1,6 +1,7 @@ package druid import ( + "context" _ "embed" "errors" "strings" @@ -64,6 +65,8 @@ func fillDataSourceName(in string, ds string) string { // AwaitDataSourceAvailable awaits for a datasource to be visible in druid table listing. func (md *MetadataService) AwaitDataSourceAvailable(dataSourceName string) error { + ctx, cancel := context.WithTimeout(context.Background(), md.awaitTimeout) + defer cancel() ticker := time.NewTicker(md.tickerDuration) defer ticker.Stop() q := query. @@ -81,7 +84,7 @@ func (md *MetadataService) AwaitDataSourceAvailable(dataSourceName string) error if len(res) >= 1 && res[0].Cnt == 1 { return nil } - case <-time.After(md.awaitTimeout): + case <-ctx.Done(): return errors.New("AwaitDataSourceAvailable timeout") } } @@ -92,6 +95,8 @@ var datasourceRecordsQuery string // AwaitRecordsCount awaits for specific recordsCount in a given datasource. func (md *MetadataService) AwaitRecordsCount(dataSourceName string, recordsCount int) error { + ctx, cancel := context.WithTimeout(context.Background(), md.awaitTimeout) + defer cancel() ticker := time.NewTicker(md.tickerDuration) defer ticker.Stop() q := query.NewSQL() @@ -108,7 +113,7 @@ func (md *MetadataService) AwaitRecordsCount(dataSourceName string, recordsCount if len(res) >= 1 && res[0].Cnt == recordsCount { return nil } - case <-time.After(md.awaitTimeout): + case <-ctx.Done(): return errors.New("AwaitRecordsCount timeout") } } diff --git a/tasks_test.go b/tasks_test.go index a2934cd..64eaa89 100644 --- a/tasks_test.go +++ b/tasks_test.go @@ -58,6 +58,9 @@ func TriggerIngestionTask[T any](d *Client, dataSourceName string, entries []T) // AwaitTaskCompletion waits for the task to complete. Function timeouts with an error after awaitTimeout nanoseconds. func AwaitTaskCompletion(client *Client, taskID string, awaitTimeout time.Duration, tickerDuration time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), awaitTimeout) + defer cancel() + ticker := time.NewTicker(tickerDuration) defer ticker.Stop() L: @@ -73,7 +76,7 @@ L: continue } break L - case <-time.After(awaitTimeout): + case <-ctx.Done(): return errors.New("AwaitTaskRunning timeout") } } @@ -82,6 +85,9 @@ L: // AwaitTaskStatus waits for the druid task status for the maximum of awaitTimeout duration, querying druid task API. func AwaitTaskStatus(client *Client, taskID string, status string, awaitTimeout time.Duration, tickerDuration time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), awaitTimeout) + defer cancel() + ticker := time.NewTicker(tickerDuration) defer ticker.Stop() for { @@ -95,7 +101,7 @@ func AwaitTaskStatus(client *Client, taskID string, status string, awaitTimeout if res.Status.Status == status { return nil } - case <-time.After(awaitTimeout): + case <-ctx.Done(): return errors.New("AwaitTaskRunning timeout") } }