Skip to content

Commit

Permalink
refactor context (cscli, pkg/database) (#3071)
Browse files Browse the repository at this point in the history
* cscli: helper require.DBClient()

* refactor pkg/database: explicit context to dbclient constructor

* lint
  • Loading branch information
mmetc authored Jun 11, 2024
1 parent 24687e9 commit bd4540b
Show file tree
Hide file tree
Showing 15 changed files with 97 additions and 49 deletions.
15 changes: 11 additions & 4 deletions cmd/crowdsec-cli/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
Expand Down Expand Up @@ -378,28 +377,35 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil {
return err
}

if ActiveDecision != nil {
alertDeleteFilter.ActiveDecisionEquals = ActiveDecision
}

if *alertDeleteFilter.ScopeEquals == "" {
alertDeleteFilter.ScopeEquals = nil
}

if *alertDeleteFilter.ValueEquals == "" {
alertDeleteFilter.ValueEquals = nil
}

if *alertDeleteFilter.ScenarioEquals == "" {
alertDeleteFilter.ScenarioEquals = nil
}

if *alertDeleteFilter.IPEquals == "" {
alertDeleteFilter.IPEquals = nil
}

if *alertDeleteFilter.RangeEquals == "" {
alertDeleteFilter.RangeEquals = nil
}

if contained != nil && *contained {
alertDeleteFilter.Contains = new(bool)
}

limit := 0
alertDeleteFilter.Limit = &limit
} else {
Expand All @@ -419,6 +425,7 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ
return fmt.Errorf("unable to delete alert: %w", err)
}
}

log.Infof("%s alert(s) deleted", alerts.NbDeleted)

return nil
Expand Down Expand Up @@ -558,14 +565,14 @@ func (cli *cliAlerts) NewFlushCmd() *cobra.Command {
/!\ This command can be used only on the same machine than the local API`,
Example: `cscli alerts flush --max-items 1000 --max-age 7d`,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil {
return err
}
db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to create new database client: %w", err)
return err
}
log.Info("Flushing alerts. !! This may take a long time !!")
err = db.FlushAlerts(maxAge, maxItems)
Expand Down
6 changes: 3 additions & 3 deletions cmd/crowdsec-cli/bouncers.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Note: This command requires database direct access, so is intended to be run on
Args: cobra.MinimumNArgs(1),
Aliases: []string{"bouncer"},
DisableAutoGenTag: true,
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
var err error

cfg := cli.cfg()
Expand All @@ -66,9 +66,9 @@ Note: This command requires database direct access, so is intended to be run on
return err
}

cli.db, err = database.NewClient(cfg.DbConfig)
cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("can't connect to the database: %w", err)
return err
}

return nil
Expand Down
6 changes: 3 additions & 3 deletions cmd/crowdsec-cli/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ Note: This command requires database direct access, so is intended to be run on
Example: `cscli machines [action]`,
DisableAutoGenTag: true,
Aliases: []string{"machine"},
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
var err error
if err = require.LAPI(cli.cfg()); err != nil {
return err
}
cli.db, err = database.NewClient(cli.cfg().DbConfig)
cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig)
if err != nil {
return fmt.Errorf("unable to create new database client: %w", err)
return err
}

return nil
Expand Down
13 changes: 6 additions & 7 deletions cmd/crowdsec-cli/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
"github.com/crowdsecurity/crowdsec/pkg/database"
)

type cliPapi struct {
Expand Down Expand Up @@ -56,12 +55,12 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command {
Short: "Get status of the Polling API",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to initialize database client: %w", err)
return err
}

apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
Expand Down Expand Up @@ -105,14 +104,14 @@ func (cli *cliPapi) NewSyncCmd() *cobra.Command {
Short: "Sync with the Polling API, pulling all non-expired orders for the instance",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
t := tomb.Tomb{}

db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(cmd.Context(), cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to initialize database client: %w", err)
return err
}

apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists)
Expand Down
10 changes: 10 additions & 0 deletions cmd/crowdsec-cli/require/require.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
"github.com/crowdsecurity/crowdsec/pkg/database"
)

func LAPI(c *csconfig.Config) error {
Expand Down Expand Up @@ -48,6 +49,15 @@ func CAPIRegistered(c *csconfig.Config) error {
return nil
}

func DBClient(ctx context.Context, dbcfg *csconfig.DatabaseCfg) (*database.Client, error) {
db, err := database.NewClient(ctx, dbcfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}

return db, nil
}

func DB(c *csconfig.Config) error {
if err := c.LoadDBConfig(true); err != nil {
return fmt.Errorf("this command requires direct database access (must be run on the local API machine): %w", err)
Expand Down
4 changes: 2 additions & 2 deletions cmd/crowdsec-cli/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,9 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
w := bytes.NewBuffer(nil)
zipWriter := zip.NewWriter(w)

db, err := database.NewClient(cfg.DbConfig)
db, err := require.DBClient(ctx, cfg.DbConfig)
if err != nil {
log.Warnf("Could not connect to database: %s", err)
log.Warn(err)
}

if err = cfg.LoadAPIServer(true); err != nil {
Expand Down
6 changes: 4 additions & 2 deletions cmd/crowdsec/run_in_svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package main

import (
"context"
"fmt"
"runtime/pprof"

Expand Down Expand Up @@ -41,9 +42,10 @@ func StartRunSvc() error {

var err error

if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(cConfig.DbConfig)
ctx := context.TODO()

if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)
if err != nil {
return fmt.Errorf("unable to create database client: %w", err)
}
Expand Down
5 changes: 4 additions & 1 deletion cmd/crowdsec/run_in_svc_windows.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"runtime/pprof"

Expand Down Expand Up @@ -80,8 +81,10 @@ func WindowsRun() error {
var dbClient *database.Client
var err error

ctx := context.TODO()

if cConfig.DbConfig != nil {
dbClient, err = database.NewClient(cConfig.DbConfig)
dbClient, err = database.NewClient(ctx, cConfig.DbConfig)

if err != nil {
return fmt.Errorf("unable to create database client: %w", err)
Expand Down
5 changes: 4 additions & 1 deletion cmd/crowdsec/serve.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"os"
"os/signal"
Expand Down Expand Up @@ -322,8 +323,10 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
crowdsecTomb = tomb.Tomb{}
pluginTomb = tomb.Tomb{}

ctx := context.TODO()

if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil {
dbClient, err := database.NewClient(cConfig.API.Server.DbConfig)
dbClient, err := database.NewClient(ctx, cConfig.API.Server.DbConfig)
if err != nil {
return fmt.Errorf("failed to get database client: %w", err)
}
Expand Down
22 changes: 14 additions & 8 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ import (
func getDBClient(t *testing.T) *database.Client {
t.Helper()

ctx := context.Background()

dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
Type: "sqlite",
DbName: "crowdsec",
DbPath: dbPath.Name(),
Expand All @@ -56,7 +58,7 @@ func getAPIC(t *testing.T) *apic {

return &apic{
AlertsAddChan: make(chan []*models.Alert),
//DecisionDeleteChan: make(chan []*models.Decision),
// DecisionDeleteChan: make(chan []*models.Decision),
dbClient: dbClient,
mu: sync.Mutex{},
startup: true,
Expand Down Expand Up @@ -176,10 +178,11 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
}

scenarios, err := api.FetchScenariosListFromDB()
require.NoError(t, err)

for machineID := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
}
require.NoError(t, err)

assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
})
Expand Down Expand Up @@ -234,6 +237,7 @@ func TestNewAPIC(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
setConfig()
httpmock.Activate()

defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("POST", "http://foobar/v3/watchers/login", httpmock.NewBytesResponder(
200, jsonMarshalX(
Expand Down Expand Up @@ -353,6 +357,7 @@ func TestAPICGetMetrics(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
apiClient := getAPIC(t)
cleanUp(apiClient)

for i, machineID := range tc.machineIDs {
apiClient.dbClient.Ent.Machine.Create().
SetMachineId(machineID).
Expand Down Expand Up @@ -548,7 +553,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {

func TestAPICWhitelists(t *testing.T) {
api := getAPIC(t)
//one whitelist on IP, one on CIDR
// one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))

Expand Down Expand Up @@ -593,7 +598,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: ptr.Of("13.2.3.4"), //wl by cidr
Value: ptr.Of("13.2.3.4"), // wl by cidr
Duration: ptr.Of("24h"),
},
},
Expand All @@ -614,7 +619,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: ptr.Of("13.2.3.5"), //wl by cidr
Value: ptr.Of("13.2.3.5"), // wl by cidr
Duration: ptr.Of("24h"),
},
},
Expand All @@ -634,7 +639,7 @@ func TestAPICWhitelists(t *testing.T) {
Scope: ptr.Of("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: ptr.Of("9.2.3.4"), //wl by ip
Value: ptr.Of("9.2.3.4"), // wl by ip
Duration: ptr.Of("24h"),
},
},
Expand Down Expand Up @@ -685,7 +690,7 @@ func TestAPICWhitelists(t *testing.T) {
err = api.PullTop(false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing
assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
assertTotalValidDecisionCount(t, api.dbClient, 4)
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
Expand Down Expand Up @@ -1103,6 +1108,7 @@ func TestAPICPush(t *testing.T) {

httpmock.Activate()
defer httpmock.DeactivateAndReset()

apic, err := apiclient.NewDefaultClient(
url,
"/api",
Expand Down
6 changes: 4 additions & 2 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler

dbClient, err := database.NewClient(config.DbConfig)
ctx := context.TODO()

dbClient, err := database.NewClient(ctx, config.DbConfig)
if err != nil {
return nil, fmt.Errorf("unable to init database client: %w", err)
}
Expand Down Expand Up @@ -227,7 +229,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {

controller := &controllers.Controller{
DBClient: dbClient,
Ectx: context.Background(),
Ectx: ctx,
Router: router,
Profiles: config.Profiles,
Log: clog,
Expand Down
Loading

0 comments on commit bd4540b

Please sign in to comment.