From 7fd4b99b17c29c360b9b1fea6894a009f5f38347 Mon Sep 17 00:00:00 2001 From: Stan Rosenberg Date: Mon, 20 Mar 2023 00:11:11 -0400 Subject: [PATCH] roachprod: implement --dry-run and --verbose Debugging roachprod in a remote (user) environment can be challenging. On several occasions, simply printing all executed (cloud) commands was sufficient to reproduce and root cause an issue. This change implements two mutually exclusive modes, namely `dry-run` and `verbose`. The latter merely logs every command before it's executed. The former suppresses execution of any command which can modify the (cloud) infrastructure; i.e., `dry-run` executes read-only commands; it logs every command, executed or not. Also note that `dry-run` leaves temporary files like VM startup scripts on disk, so that they can be readily examined. Thus, `dry-run` is primarily for dumping _all_ commands without the risk of modifying anything; `verbose` is primarily for live debugging. Release note: None Epic: none Fixes: #99091 --- pkg/BUILD.bazel | 3 + pkg/cmd/roachprod/cli/commands.go | 6 +- pkg/cmd/roachprod/cli/flags.go | 6 +- pkg/cmd/roachprod/cli/util.go | 9 +- pkg/roachprod/BUILD.bazel | 1 + pkg/roachprod/cloud/cluster_cloud.go | 2 +- pkg/roachprod/cloud/gc.go | 4 +- pkg/roachprod/clusters_cache.go | 21 +- pkg/roachprod/config/config.go | 4 + pkg/roachprod/install/BUILD.bazel | 1 + pkg/roachprod/install/cluster_synced.go | 18 +- pkg/roachprod/install/cockroach.go | 38 +-- pkg/roachprod/install/expander.go | 6 +- pkg/roachprod/install/services.go | 20 +- pkg/roachprod/install/services_test.go | 10 +- pkg/roachprod/install/session.go | 7 +- pkg/roachprod/logger/BUILD.bazel | 2 + pkg/roachprod/logger/log.go | 4 +- pkg/roachprod/logger/log_redirect_test.go | 41 +-- pkg/roachprod/logger/test_utils.go | 52 +++ pkg/roachprod/roachprod.go | 31 +- pkg/roachprod/vm/aws/BUILD.bazel | 1 + pkg/roachprod/vm/aws/aws.go | 70 ++-- pkg/roachprod/vm/aws/keys.go | 7 +- pkg/roachprod/vm/aws/support.go | 41 --- 
pkg/roachprod/vm/cli/BUILD.bazel | 25 ++ pkg/roachprod/vm/cli/cli.go | 140 ++++++++ pkg/roachprod/vm/cli/cli_test.go | 113 +++++++ pkg/roachprod/vm/dns.go | 11 +- pkg/roachprod/vm/gce/BUILD.bazel | 1 + pkg/roachprod/vm/gce/dns.go | 81 ++--- pkg/roachprod/vm/gce/gcloud.go | 320 ++++++++----------- pkg/roachprod/vm/gce/testutils/BUILD.bazel | 2 + pkg/roachprod/vm/gce/testutils/dns_server.go | 18 +- pkg/roachprod/vm/gce/utils.go | 14 +- pkg/roachprod/vm/local/dns.go | 19 +- pkg/roachprod/vm/local/dns_test.go | 6 +- pkg/roachprod/vm/local/local.go | 39 ++- 38 files changed, 754 insertions(+), 440 deletions(-) create mode 100644 pkg/roachprod/logger/test_utils.go create mode 100644 pkg/roachprod/vm/cli/BUILD.bazel create mode 100644 pkg/roachprod/vm/cli/cli.go create mode 100644 pkg/roachprod/vm/cli/cli_test.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 0f051e1801e6..0fb12718efeb 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -312,6 +312,7 @@ ALL_TESTS = [ "//pkg/roachprod/prometheus:prometheus_test", "//pkg/roachprod/promhelperclient:promhelperclient_test", "//pkg/roachprod/ssh:ssh_test", + "//pkg/roachprod/vm/cli:cli_test", "//pkg/roachprod/vm/gce:gce_test", "//pkg/roachprod/vm/local:local_test", "//pkg/roachprod/vm:vm_test", @@ -1634,6 +1635,8 @@ GO_TARGETS = [ "//pkg/roachprod/vm/aws/terraformgen:terraformgen_lib", "//pkg/roachprod/vm/aws:aws", "//pkg/roachprod/vm/azure:azure", + "//pkg/roachprod/vm/cli:cli", + "//pkg/roachprod/vm/cli:cli_test", "//pkg/roachprod/vm/flagstub:flagstub", "//pkg/roachprod/vm/gce/testutils:testutils", "//pkg/roachprod/vm/gce:gce", diff --git a/pkg/cmd/roachprod/cli/commands.go b/pkg/cmd/roachprod/cli/commands.go index 9e7d26ad3b69..91ff1f051d15 100644 --- a/pkg/cmd/roachprod/cli/commands.go +++ b/pkg/cmd/roachprod/cli/commands.go @@ -1117,7 +1117,7 @@ func buildSSHKeysListCmd() *cobra.Command { Use: "list", Short: "list every SSH public key installed on clusters managed by roachprod", Run: wrap(func(cmd *cobra.Command, 
args []string) error { - authorizedKeys, err := gce.GetUserAuthorizedKeys() + authorizedKeys, err := gce.GetUserAuthorizedKeys(config.Logger) if err != nil { return err } @@ -1151,7 +1151,7 @@ func buildSSHKeysAddCmd() *cobra.Command { } fmt.Printf("Adding new public key for user %s...\n", ak.User) - return gce.AddUserAuthorizedKey(ak) + return gce.AddUserAuthorizedKey(config.Logger, ak) }), } sshKeysAddCmd.Flags().StringVar(&sshKeyUser, "user", config.OSUser.Username, @@ -1168,7 +1168,7 @@ func buildSSHKeysRemoveCmd() *cobra.Command { Run: wrap(func(cmd *cobra.Command, args []string) error { user := args[0] - existingKeys, err := gce.GetUserAuthorizedKeys() + existingKeys, err := gce.GetUserAuthorizedKeys(config.Logger) if err != nil { return fmt.Errorf("failed to fetch existing keys: %w", err) } diff --git a/pkg/cmd/roachprod/cli/flags.go b/pkg/cmd/roachprod/cli/flags.go index 0fd3883f39c1..c2060429c730 100644 --- a/pkg/cmd/roachprod/cli/flags.go +++ b/pkg/cmd/roachprod/cli/flags.go @@ -121,6 +121,10 @@ func initRootCmdFlags(rootCmd *cobra.Command) { "use-shared-user", true, fmt.Sprintf("use the shared user %q for ssh rather than your user %q", config.SharedUser, config.OSUser.Username)) + rootCmd.PersistentFlags().BoolVarP(&config.DryRun, "dry-run", "d", + false, "dry-run mode (log all commands & execute read-only ones; i.e., no infra changes)") + rootCmd.PersistentFlags().BoolVarP(&config.Verbose, "verbose", "v", + false, "verbose mode (log all executed commands)") } func initCreateCmdFlags(createCmd *cobra.Command) { @@ -201,7 +205,7 @@ func initListCmdFlags(listCmd *cobra.Command) { "Show cost estimates", ) listCmd.Flags().BoolVarP(&listDetails, - "details", "d", false, "Show cluster details") + "details", "", false, "Show cluster details") listCmd.Flags().BoolVar(&listJSON, "json", false, "Show cluster specs in a json format") listCmd.Flags().BoolVarP(&listMine, diff --git a/pkg/cmd/roachprod/cli/util.go b/pkg/cmd/roachprod/cli/util.go index 
a62e31a8bac3..c8582c238328 100644 --- a/pkg/cmd/roachprod/cli/util.go +++ b/pkg/cmd/roachprod/cli/util.go @@ -12,6 +12,7 @@ import ( "text/tabwriter" "time" + "github.com/cockroachdb/cockroach/pkg/roachprod/config" rperrors "github.com/cockroachdb/cockroach/pkg/roachprod/errors" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce" @@ -202,7 +203,7 @@ Node specification func ValidateAndConfigure(cmd *cobra.Command, args []string) { // Skip validation for commands that are self-sufficient. switch cmd.Name() { - case "help", "version", "list": + case "help", "version": return } @@ -237,4 +238,10 @@ func ValidateAndConfigure(cmd *cobra.Command, args []string) { } providersSet[p] = struct{}{} } + if config.DryRun { + if config.Verbose { + printErrAndExit(fmt.Errorf("--verbose and --dry-run are mutually exclusive")) + } + config.Logger.Printf("Enabling [***Experimental***] --dry-run mode. No infra changes will be made!") + } } diff --git a/pkg/roachprod/BUILD.bazel b/pkg/roachprod/BUILD.bazel index 203bbdf08710..072365a1fa8c 100644 --- a/pkg/roachprod/BUILD.bazel +++ b/pkg/roachprod/BUILD.bazel @@ -27,6 +27,7 @@ go_library( "//pkg/roachprod/vm", "//pkg/roachprod/vm/aws", "//pkg/roachprod/vm/azure", + "//pkg/roachprod/vm/cli", "//pkg/roachprod/vm/gce", "//pkg/roachprod/vm/local", "//pkg/server/debug/replay", diff --git a/pkg/roachprod/cloud/cluster_cloud.go b/pkg/roachprod/cloud/cluster_cloud.go index cd6043890b21..88e88c7fdd4a 100644 --- a/pkg/roachprod/cloud/cluster_cloud.go +++ b/pkg/roachprod/cloud/cluster_cloud.go @@ -433,7 +433,7 @@ func DestroyCluster(l *logger.Logger, c *Cluster) error { // and clean-up entries prematurely. 
stopSpinner = ui.NewDefaultSpinner(l, "Destroying DNS entries").Start() dnsErr := vm.FanOutDNS(c.VMs, func(p vm.DNSProvider, vms vm.List) error { - return p.DeleteRecordsBySubdomain(context.Background(), c.Name) + return p.DeleteRecordsBySubdomain(context.Background(), l, c.Name) }) stopSpinner() diff --git a/pkg/roachprod/cloud/gc.go b/pkg/roachprod/cloud/gc.go index 15596dd7452c..cb856024e830 100644 --- a/pkg/roachprod/cloud/gc.go +++ b/pkg/roachprod/cloud/gc.go @@ -518,7 +518,7 @@ func GCDNS(l *logger.Logger, cloud *Cloud, dryrun bool) error { if !ok { continue } - records, err := p.ListRecords(ctx) + records, err := p.ListRecords(ctx, l) if err != nil { return err } @@ -541,7 +541,7 @@ func GCDNS(l *logger.Logger, cloud *Cloud, dryrun bool) error { sort.Strings(recordNames) if err := destroyResource(dryrun, func() error { - return p.DeleteRecordsByName(ctx, recordNames...) + return p.DeleteRecordsByName(ctx, l, recordNames...) }); err != nil { return err } diff --git a/pkg/roachprod/clusters_cache.go b/pkg/roachprod/clusters_cache.go index e0610a2e8046..11898263a4cd 100644 --- a/pkg/roachprod/clusters_cache.go +++ b/pkg/roachprod/clusters_cache.go @@ -86,10 +86,20 @@ func saveCluster(l *logger.Logger, c *cloud.Cluster) error { err = errors.CombineErrors(err, tmpFile.Sync()) err = errors.CombineErrors(err, tmpFile.Close()) if err == nil { - err = os.Rename(tmpFile.Name(), filename) + if config.DryRun || config.Verbose { + l.Printf("exec: mv %s %s", tmpFile.Name(), filename) + } + if !config.DryRun { + err = os.Rename(tmpFile.Name(), filename) + } } if err != nil { - _ = os.Remove(tmpFile.Name()) + if config.DryRun || config.Verbose { + l.Printf("exec: rm %s", tmpFile.Name()) + } + if !config.DryRun { + _ = os.Remove(tmpFile.Name()) + } return err } return nil @@ -258,5 +268,12 @@ func (localVMStorage) SaveCluster(l *logger.Logger, cluster *cloud.Cluster) erro // DeleteCluster is part of the local.VMStorage interface. 
func (localVMStorage) DeleteCluster(l *logger.Logger, name string) error { path := clusterFilename(name) + if config.DryRun || config.Verbose { + l.Printf("exec: rm %s", path) + } + if config.DryRun { + return nil + } + return os.Remove(path) } diff --git a/pkg/roachprod/config/config.go b/pkg/roachprod/config/config.go index 525e63371773..1151cd80a2b1 100644 --- a/pkg/roachprod/config/config.go +++ b/pkg/roachprod/config/config.go @@ -31,6 +31,10 @@ var ( OSUser *user.User // Quiet is used to disable fancy progress output. Quiet = false + // DryRun disables executing commands which would otherwise cause infra. changes. + DryRun = false + // Verbose enables logging all executed commands. What gets logged is a superset of DryRun. + Verbose = false // The default roachprod logger. // N.B. When roachprod is used via CLI, this logger is used for all output. // When roachprod is used via API (e.g. from roachtest), this logger is used only in the few cases, diff --git a/pkg/roachprod/install/BUILD.bazel b/pkg/roachprod/install/BUILD.bazel index 3d6390bdc97d..b8864a087d43 100644 --- a/pkg/roachprod/install/BUILD.bazel +++ b/pkg/roachprod/install/BUILD.bazel @@ -34,6 +34,7 @@ go_library( "//pkg/roachprod/ui", "//pkg/roachprod/vm", "//pkg/roachprod/vm/aws", + "//pkg/roachprod/vm/cli", "//pkg/roachprod/vm/gce", "//pkg/roachprod/vm/local", "//pkg/testutils", diff --git a/pkg/roachprod/install/cluster_synced.go b/pkg/roachprod/install/cluster_synced.go index 8742449994f4..4f90518fb7ab 100644 --- a/pkg/roachprod/install/cluster_synced.go +++ b/pkg/roachprod/install/cluster_synced.go @@ -32,6 +32,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/ui" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/aws" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/local" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/retry" 
@@ -438,7 +439,7 @@ func (c *SyncedCluster) Stop( return err } - services, err := c.DiscoverServices(ctx, name, ServiceTypeSQL) + services, err := c.DiscoverServices(ctx, l, name, ServiceTypeSQL) if err != nil { return err } @@ -2339,6 +2340,8 @@ func (c *SyncedCluster) Logs( cmd.Stdout = os.Stdout cmd.Stderr = &stderrBuf + cli.MaybeLogCmd(context.Background(), l, cmd) + if err := cmd.Run(); err != nil { if ctx.Err() != nil { return nil @@ -2379,6 +2382,9 @@ func (c *SyncedCluster) Logs( cmd.Stdout = out var errBuf bytes.Buffer cmd.Stderr = &errBuf + + cli.MaybeLogCmd(ctx, l, cmd) + if err := cmd.Run(); err != nil && ctx.Err() == nil { return errors.Wrapf(err, "failed to run cockroach debug merge-logs:\n%v", errBuf.String()) } @@ -2632,7 +2638,7 @@ func (c *SyncedCluster) pgurls( } m := make(map[Node]string, len(hosts)) for node, host := range hosts { - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeSQL, sqlInstance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeSQL, sqlInstance) if err != nil { return nil, err } @@ -2667,7 +2673,7 @@ func (c *SyncedCluster) loadBalancerURL( sqlInstance int, auth PGAuthMode, ) (string, error) { - services, err := c.DiscoverServices(ctx, virtualClusterName, ServiceTypeSQL) + services, err := c.DiscoverServices(ctx, l, virtualClusterName, ServiceTypeSQL) if err != nil { return "", err } @@ -2802,6 +2808,8 @@ func scp(ctx context.Context, l *logger.Logger, src, dest string) (*RunResultDet cmd := exec.CommandContext(ctx, args[0], args[1:]...) 
cmd.WaitDelay = time.Second // make sure the call below returns when the context is canceled + cli.MaybeLogCmd(context.Background(), l, cmd) + out, err := cmd.CombinedOutput() if err != nil { err = rperrors.NewSSHError(errors.Wrapf(err, "~ %s\n%s", strings.Join(args, " "), out)) @@ -2993,10 +3001,10 @@ func (c *SyncedCluster) Init(ctx context.Context, l *logger.Logger, node Node) e // allPublicAddrs returns a string that can be used when starting cockroach to // indicate the location of all nodes in the cluster. -func (c *SyncedCluster) allPublicAddrs(ctx context.Context) (string, error) { +func (c *SyncedCluster) allPublicAddrs(ctx context.Context, l *logger.Logger) (string, error) { var addrs []string for _, node := range c.Nodes { - port, err := c.NodePort(ctx, node, "" /* virtualClusterName */, 0 /* sqlInstance */) + port, err := c.NodePort(ctx, l, node, "" /* virtualClusterName */, 0 /* sqlInstance */) if err != nil { return "", err } diff --git a/pkg/roachprod/install/cockroach.go b/pkg/roachprod/install/cockroach.go index c22e4bef7065..913baa10f1c1 100644 --- a/pkg/roachprod/install/cockroach.go +++ b/pkg/roachprod/install/cockroach.go @@ -269,7 +269,7 @@ func (c *SyncedCluster) allowServiceRegistration() bool { func (c *SyncedCluster) maybeRegisterServices( ctx context.Context, l *logger.Logger, startOpts StartOpts, portFunc FindOpenPortsFunc, ) error { - serviceMap, err := c.MapServices(ctx, startOpts.VirtualClusterName, startOpts.SQLInstance) + serviceMap, err := c.MapServices(ctx, l, startOpts.VirtualClusterName, startOpts.SQLInstance) if err != nil { return err } @@ -296,7 +296,7 @@ func (c *SyncedCluster) maybeRegisterServices( if err != nil { return err } - return c.RegisterServices(ctx, servicesToRegister) + return c.RegisterServices(ctx, l, servicesToRegister) } // servicesWithOpenPortSelection returns services to be registered for @@ -677,9 +677,9 @@ func (c *SyncedCluster) NodeURL( // NodePort returns the system tenant's SQL port for the given 
node. func (c *SyncedCluster) NodePort( - ctx context.Context, node Node, virtualClusterName string, sqlInstance int, + ctx context.Context, l *logger.Logger, node Node, virtualClusterName string, sqlInstance int, ) (int, error) { - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeSQL, sqlInstance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeSQL, sqlInstance) if err != nil { return 0, err } @@ -688,9 +688,9 @@ func (c *SyncedCluster) NodePort( // NodeUIPort returns the system tenant's AdminUI port for the given node. func (c *SyncedCluster) NodeUIPort( - ctx context.Context, node Node, virtualClusterName string, sqlInstance int, + ctx context.Context, l *logger.Logger, node Node, virtualClusterName string, sqlInstance int, ) (int, error) { - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeUI, sqlInstance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeUI, sqlInstance) if err != nil { return 0, err } @@ -715,7 +715,7 @@ func (c *SyncedCluster) ExecOrInteractiveSQL( if len(c.Nodes) != 1 { return fmt.Errorf("invalid number of nodes for interactive sql: %d", len(c.Nodes)) } - desc, err := c.DiscoverService(ctx, c.Nodes[0], virtualClusterName, ServiceTypeSQL, sqlInstance) + desc, err := c.DiscoverService(ctx, l, c.Nodes[0], virtualClusterName, ServiceTypeSQL, sqlInstance) if err != nil { return err } @@ -742,7 +742,7 @@ func (c *SyncedCluster) ExecSQL( display := fmt.Sprintf("%s: executing sql", c.Name) results, _, err := c.ParallelE(ctx, l, WithNodes(nodes).WithDisplay(display).WithFailSlow(), func(ctx context.Context, node Node) (*RunResultDetails, error) { - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeSQL, sqlInstance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeSQL, sqlInstance) if err != nil { return nil, err } @@ -1001,7 +1001,7 @@ func (c *SyncedCluster) generateStartArgs( instance := 
startOpts.SQLInstance var sqlPort int if startOpts.Target == StartServiceForVirtualCluster { - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeSQL, instance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeSQL, instance) if err != nil { return nil, err } @@ -1011,14 +1011,14 @@ func (c *SyncedCluster) generateStartArgs( virtualClusterName = SystemInterfaceName // System interface instance is always 0. instance = 0 - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeSQL, instance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeSQL, instance) if err != nil { return nil, err } sqlPort = desc.Port args = append(args, fmt.Sprintf("--listen-addr=%s:%d", listenHost, sqlPort)) } - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeUI, instance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeUI, instance) if err != nil { return nil, err } @@ -1041,7 +1041,7 @@ func (c *SyncedCluster) generateStartArgs( joinTargets := startOpts.GetJoinTargets() addresses := make([]string, len(joinTargets)) for i, joinNode := range startOpts.GetJoinTargets() { - desc, err := c.DiscoverService(ctx, joinNode, SystemInterfaceName, ServiceTypeSQL, 0) + desc, err := c.DiscoverService(ctx, l, joinNode, SystemInterfaceName, ServiceTypeSQL, 0) if err != nil { return nil, err } @@ -1050,7 +1050,7 @@ func (c *SyncedCluster) generateStartArgs( args = append(args, fmt.Sprintf("--join=%s", strings.Join(addresses, ","))) } if startOpts.Target == StartServiceForVirtualCluster { - storageAddrs, err := startOpts.StorageCluster.allPublicAddrs(ctx) + storageAddrs, err := startOpts.StorageCluster.allPublicAddrs(ctx, l) if err != nil { return nil, err } @@ -1196,7 +1196,7 @@ func (c *SyncedCluster) maybeScaleMem(val int) int { func (c *SyncedCluster) initializeCluster(ctx context.Context, l *logger.Logger, node Node) error { l.Printf("%s: initializing 
cluster\n", c.Name) - cmd, err := c.generateInitCmd(ctx, node) + cmd, err := c.generateInitCmd(ctx, l, node) if err != nil { return err } @@ -1348,7 +1348,7 @@ func (c *SyncedCluster) generateClusterSettingCmd( pathPrefix = fmt.Sprintf("%s_", virtualCluster) } path := fmt.Sprintf("%s/%ssettings-initialized", c.NodeDir(node, 1 /* storeIndex */), pathPrefix) - port, err := c.NodePort(ctx, node, "" /* virtualClusterName */, 0 /* sqlInstance */) + port, err := c.NodePort(ctx, l, node, "" /* virtualClusterName */, 0 /* sqlInstance */) if err != nil { return "", err } @@ -1363,14 +1363,16 @@ func (c *SyncedCluster) generateClusterSettingCmd( return clusterSettingsCmd, nil } -func (c *SyncedCluster) generateInitCmd(ctx context.Context, node Node) (string, error) { +func (c *SyncedCluster) generateInitCmd( + ctx context.Context, l *logger.Logger, node Node, +) (string, error) { var initCmd string if c.IsLocal() { initCmd = fmt.Sprintf(`cd %s ; `, c.localVMDir(node)) } path := fmt.Sprintf("%s/%s", c.NodeDir(node, 1 /* storeIndex */), "cluster-bootstrapped") - port, err := c.NodePort(ctx, node, "" /* virtualClusterName */, 0 /* sqlInstance */) + port, err := c.NodePort(ctx, l, node, "" /* virtualClusterName */, 0 /* sqlInstance */) if err != nil { return "", err } @@ -1569,7 +1571,7 @@ func (c *SyncedCluster) createFixedBackupSchedule( node := c.Nodes[0] binary := cockroachNodeBinary(c, node) - port, err := c.NodePort(ctx, node, startOpts.VirtualClusterName, startOpts.SQLInstance) + port, err := c.NodePort(ctx, l, node, startOpts.VirtualClusterName, startOpts.SQLInstance) if err != nil { return err } diff --git a/pkg/roachprod/install/expander.go b/pkg/roachprod/install/expander.go index f32f53f9c87c..2a6ca231c2a8 100644 --- a/pkg/roachprod/install/expander.go +++ b/pkg/roachprod/install/expander.go @@ -206,7 +206,7 @@ func (e *expander) maybeExpandPgHost( switch strings.ToLower(m[1]) { case ":lb": - services, err := c.DiscoverServices(ctx, virtualClusterName, 
ServiceTypeSQL, ServiceInstancePredicate(sqlInstance)) + services, err := c.DiscoverServices(ctx, l, virtualClusterName, ServiceTypeSQL, ServiceInstancePredicate(sqlInstance)) if err != nil { return "", false, err } @@ -251,7 +251,7 @@ func (e *expander) maybeExpandPgPort( if e.pgPorts == nil { e.pgPorts = make(map[Node]string, len(c.VMs)) for _, node := range allNodes(len(c.VMs)) { - desc, err := c.DiscoverService(ctx, node, virtualClusterName, ServiceTypeSQL, sqlInstance) + desc, err := c.DiscoverService(ctx, l, node, virtualClusterName, ServiceTypeSQL, sqlInstance) if err != nil { return s, false, err } @@ -276,7 +276,7 @@ func (e *expander) maybeExpandUIPort( e.uiPorts = make(map[Node]string, len(c.VMs)) for _, node := range allNodes(len(c.VMs)) { // TODO(herko): Add support for separate-process services. - e.uiPorts[node] = fmt.Sprint(c.NodeUIPort(ctx, node, "" /* virtualClusterName */, 0 /* sqlInstance */)) + e.uiPorts[node] = fmt.Sprint(c.NodeUIPort(ctx, l, node, "" /* virtualClusterName */, 0 /* sqlInstance */)) } } diff --git a/pkg/roachprod/install/services.go b/pkg/roachprod/install/services.go index dfe499263226..723250b72627 100644 --- a/pkg/roachprod/install/services.go +++ b/pkg/roachprod/install/services.go @@ -134,6 +134,7 @@ func serviceNameComponents(name string) (string, ServiceType, error) { // nodes. 
func (c *SyncedCluster) DiscoverServices( ctx context.Context, + l *logger.Logger, virtualClusterName string, serviceType ServiceType, predicates ...ServicePredicate, @@ -145,7 +146,7 @@ func (c *SyncedCluster) DiscoverServices( mu := syncutil.Mutex{} records := make([]vm.DNSRecord, 0) err := vm.FanOutDNS(c.VMs, func(dnsProvider vm.DNSProvider, _ vm.List) error { - r, lookupErr := dnsProvider.LookupSRVRecords(ctx, serviceDNSName(dnsProvider, virtualClusterName, serviceType, c.Name)) + r, lookupErr := dnsProvider.LookupSRVRecords(ctx, l, serviceDNSName(dnsProvider, virtualClusterName, serviceType, c.Name)) if lookupErr != nil { return lookupErr } @@ -169,6 +170,7 @@ func (c *SyncedCluster) DiscoverServices( // for the service type. func (c *SyncedCluster) DiscoverService( ctx context.Context, + l *logger.Logger, node Node, virtualClusterName string, serviceType ServiceType, @@ -184,7 +186,7 @@ func (c *SyncedCluster) DiscoverService( // This call should return service descriptors for the storage // service and for external-process virtual clusters. services, err := c.DiscoverServices( - ctx, virtualClusterName, serviceType, + ctx, l, virtualClusterName, serviceType, ServiceNodePredicate(node), ServiceModePredicate(ServiceModeExternal), ServiceInstancePredicate(sqlInstance), ) if err != nil { @@ -199,7 +201,7 @@ func (c *SyncedCluster) DiscoverService( // shared-process virtual cluster. Find the corresponding system // service, if any. services, err = c.DiscoverServices( - ctx, SystemInterfaceName, serviceType, ServiceNodePredicate(node), + ctx, l, SystemInterfaceName, serviceType, ServiceNodePredicate(node), ) if err != nil { return ServiceDesc{}, err @@ -271,15 +273,15 @@ func (c *SyncedCluster) ListLoadBalancers(l *logger.Logger) ([]vm.ServiceAddress // MapServices discovers all service types for a given virtual cluster // and instance and maps it by node and service type. 
func (c *SyncedCluster) MapServices( - ctx context.Context, virtualClusterName string, instance int, + ctx context.Context, l *logger.Logger, virtualClusterName string, instance int, ) (NodeServiceMap, error) { nodeFilter := ServiceNodePredicate(c.Nodes...) instanceFilter := ServiceInstancePredicate(instance) - sqlServices, err := c.DiscoverServices(ctx, virtualClusterName, ServiceTypeSQL, nodeFilter, instanceFilter) + sqlServices, err := c.DiscoverServices(ctx, l, virtualClusterName, ServiceTypeSQL, nodeFilter, instanceFilter) if err != nil { return nil, err } - uiServices, err := c.DiscoverServices(ctx, virtualClusterName, ServiceTypeUI, nodeFilter, instanceFilter) + uiServices, err := c.DiscoverServices(ctx, l, virtualClusterName, ServiceTypeUI, nodeFilter, instanceFilter) if err != nil { return nil, err } @@ -297,7 +299,9 @@ func (c *SyncedCluster) MapServices( // RegisterServices registers services with the DNS provider. This function is // lenient and will not return an error if no DNS provider is available to // register the service. -func (c *SyncedCluster) RegisterServices(ctx context.Context, services ServiceDescriptors) error { +func (c *SyncedCluster) RegisterServices( + ctx context.Context, l *logger.Logger, services ServiceDescriptors, +) error { servicesByDNSProvider := make(map[string]ServiceDescriptors) for _, desc := range services { dnsProvider := c.VMs[desc.Node-1].DNSProvider @@ -323,7 +327,7 @@ func (c *SyncedCluster) RegisterServices(ctx context.Context, services ServiceDe } records = append(records, vm.CreateSRVRecord(name, srvData)) } - err := dnsProvider.CreateRecords(ctx, records...) + err := dnsProvider.CreateRecords(ctx, l, records...) 
if err != nil { return err } diff --git a/pkg/roachprod/install/services_test.go b/pkg/roachprod/install/services_test.go index be29ae4b0d72..87d795f2854d 100644 --- a/pkg/roachprod/install/services_test.go +++ b/pkg/roachprod/install/services_test.go @@ -37,7 +37,7 @@ func TestServicePorts(t *testing.T) { z2NS := local.NewDNSProvider(t.TempDir(), "z2") vm.Providers["p2"] = &testProvider{DNSProvider: z2NS} - err := z1NS.CreateRecords(ctx, + err := z1NS.CreateRecords(ctx, nil, vm.CreateSRVRecord(serviceDNSName(z1NS, "t1", ServiceTypeSQL, clusterName), net.SRV{ Target: "host1.rp.", Port: 12345, @@ -45,7 +45,7 @@ func TestServicePorts(t *testing.T) { ) require.NoError(t, err) - err = z2NS.CreateRecords(ctx, + err = z2NS.CreateRecords(ctx, nil, vm.CreateSRVRecord(serviceDNSName(z2NS, "t1", ServiceTypeSQL, clusterName), net.SRV{ Target: "host1.rp.", Port: 12346, @@ -72,7 +72,7 @@ func TestServicePorts(t *testing.T) { Nodes: allNodes(2), } - descriptors, err := c.DiscoverServices(context.Background(), "t1", ServiceTypeSQL, ServiceNodePredicate(c.Nodes...)) + descriptors, err := c.DiscoverServices(context.Background(), nil, "t1", ServiceTypeSQL, ServiceNodePredicate(c.Nodes...)) sort.Slice(descriptors, func(i, j int) bool { return descriptors[i].Port < descriptors[j].Port }) @@ -226,11 +226,11 @@ func TestMultipleRegistrations(t *testing.T) { verify := func(c *SyncedCluster, servicesToRegister [][]ServiceDesc) bool { for _, services := range servicesToRegister { if len(services) == 0 { - err := testDNS.DeleteRecordsBySubdomain(ctx, c.Name) + err := testDNS.DeleteRecordsBySubdomain(ctx, nil, c.Name) require.NoError(t, err) continue } - err := c.RegisterServices(ctx, services) + err := c.RegisterServices(ctx, nil, services) require.NoError(t, err) } return true diff --git a/pkg/roachprod/install/session.go b/pkg/roachprod/install/session.go index 44525cdca6ce..7150bf757edc 100644 --- a/pkg/roachprod/install/session.go +++ b/pkg/roachprod/install/session.go @@ -17,6 +17,7 @@ 
import ( "github.com/cockroachdb/cockroach/pkg/roachprod/config" rperrors "github.com/cockroachdb/cockroach/pkg/roachprod/errors" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" ) @@ -40,6 +41,7 @@ type remoteSession struct { *exec.Cmd cancel func() logfile string // captures ssh -vvv + logger *logger.Logger } type remoteCommand struct { @@ -127,7 +129,7 @@ func newRemoteSession(l *logger.Logger, command *remoteCommand) *remoteSession { args = append(args, command.cmd) ctx, cancel := context.WithCancel(context.Background()) fullCmd := exec.CommandContext(ctx, "ssh", args...) - return &remoteSession{fullCmd, cancel, logfile} + return &remoteSession{fullCmd, cancel, logfile, l} } func (s *remoteSession) errWithDebug(err error) error { @@ -147,6 +149,7 @@ func (s *remoteSession) CombinedOutput(ctx context.Context) ([]byte, error) { commandFinished := make(chan struct{}) go func() { + cli.MaybeLogCmd(context.Background(), s.logger, s.Cmd) b, err = s.Cmd.CombinedOutput() err = s.errWithDebug(err) close(commandFinished) @@ -165,6 +168,7 @@ func (s *remoteSession) Run(ctx context.Context) error { var err error commandFinished := make(chan struct{}) go func() { + cli.MaybeLogCmd(context.Background(), s.logger, s.Cmd) err = s.errWithDebug(s.Cmd.Run()) close(commandFinished) }() @@ -179,6 +183,7 @@ func (s *remoteSession) Run(ctx context.Context) error { } func (s *remoteSession) Start() error { + cli.MaybeLogCmd(context.Background(), s.logger, s.Cmd) return s.errWithDebug(s.Cmd.Start()) } diff --git a/pkg/roachprod/logger/BUILD.bazel b/pkg/roachprod/logger/BUILD.bazel index aa095f938843..9eadb38c309b 100644 --- a/pkg/roachprod/logger/BUILD.bazel +++ b/pkg/roachprod/logger/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "log.go", "log_redirect.go", + "test_utils.go", ], importpath = 
"github.com/cockroachdb/cockroach/pkg/roachprod/logger", visibility = ["//visibility:public"], @@ -13,6 +14,7 @@ go_library( "//pkg/util/log/logconfig", "//pkg/util/log/logpb", "//pkg/util/syncutil", + "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/roachprod/logger/log.go b/pkg/roachprod/logger/log.go index 2900200d2c0a..e1835c9e2800 100644 --- a/pkg/roachprod/logger/log.go +++ b/pkg/roachprod/logger/log.go @@ -126,8 +126,8 @@ func (cfg *Config) NewLogger(path string) (*Logger, error) { return &Logger{ Stdout: newSafeWriter(stdout), Stderr: newSafeWriter(stderr), - stdoutL: log.New(os.Stdout, cfg.Prefix, logFlags), - stderrL: log.New(os.Stderr, cfg.Prefix, logFlags), + stdoutL: log.New(stdout, cfg.Prefix, logFlags), + stderrL: log.New(stderr, cfg.Prefix, logFlags), }, nil } diff --git a/pkg/roachprod/logger/log_redirect_test.go b/pkg/roachprod/logger/log_redirect_test.go index 8d8dcf88a850..8e61ea981966 100644 --- a/pkg/roachprod/logger/log_redirect_test.go +++ b/pkg/roachprod/logger/log_redirect_test.go @@ -7,53 +7,18 @@ package logger import ( "context" - "strings" "testing" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/stretchr/testify/require" ) -type mockLogger struct { - logger *Logger - writer *mockWriter -} - -type mockWriter struct { - lines []string -} - -func (w *mockWriter) Write(p []byte) (n int, err error) { - w.lines = append(w.lines, string(p)) - return len(p), nil -} - -func newMockLogger(t *testing.T) *mockLogger { - writer := &mockWriter{} - logConf := Config{Stdout: writer, Stderr: writer} - l, err := logConf.NewLogger("" /* path */) - require.NoError(t, err) - return &mockLogger{logger: l, writer: writer} -} - -func requireLine(t *testing.T, l *mockLogger, line string) { - t.Helper() - found := false - for _, logLine := range l.writer.lines { - if strings.Contains(logLine, line) { - found = true - break - } - } - require.True(t, found, "expected line not found: %s", line) -} - func TestLogRedirect(t *testing.T) { - 
l := newMockLogger(t) - TestingCRDBLogConfig(l.logger) + l := NewMockLogger(t) + TestingCRDBLogConfig(l.Logger) ctx := context.Background() log.Infof(ctx, "[simple test]") - requireLine(t, l, "[simple test]") + RequireLine(t, l, "[simple test]") require.Equal(t, 1, len(l.writer.lines)) } diff --git a/pkg/roachprod/logger/test_utils.go b/pkg/roachprod/logger/test_utils.go new file mode 100644 index 000000000000..e2088bdc3de6 --- /dev/null +++ b/pkg/roachprod/logger/test_utils.go @@ -0,0 +1,52 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package logger + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type mockLogger struct { + Logger *Logger + writer *mockWriter +} + +type mockWriter struct { + lines []string +} + +func (w *mockWriter) Write(p []byte) (n int, err error) { + w.lines = append(w.lines, string(p)) + return len(p), nil +} + +// Creates a logger whose entire output (stdout/stderr included) is redirected to mockWriter. 
+func NewMockLogger(t *testing.T) *mockLogger { + writer := &mockWriter{} + logConf := Config{Stdout: writer, Stderr: writer} + l, err := logConf.NewLogger("") + require.NoError(t, err) + return &mockLogger{Logger: l, writer: writer} +} + +func RequireLine(t *testing.T, l *mockLogger, line string) { + t.Helper() + found := false + for _, logLine := range l.writer.lines { + if strings.Contains(logLine, line) { + found = true + break + } + } + require.True(t, found, "expected line not found: %s", line) +} + +func RequireEqual(t *testing.T, l *mockLogger, expectedLines []string) { + require.Equal(t, expectedLines, l.writer.lines) +} diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 88a57ab98fcc..f72c75477cbf 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -44,6 +44,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/aws" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/azure" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/local" "github.com/cockroachdb/cockroach/pkg/server/debug/replay" @@ -681,6 +682,8 @@ func SetupSSH(ctx context.Context, l *logger.Logger, clusterName string) error { for _, v := range cloudCluster.VMs { cmd := exec.Command("ssh-keygen", "-R", v.PublicIP) + cli.MaybeLogCmd(context.Background(), l, cmd) + out, err := cmd.CombinedOutput() if err != nil { l.Printf("could not clear ssh key for hostname %s:\n%s", v.PublicIP, string(out)) @@ -702,7 +705,7 @@ func SetupSSH(ctx context.Context, l *logger.Logger, clusterName string) error { } // Fetch public keys from gcloud to set up ssh access for all users into the // shared ubuntu user. 
- authorizedKeys, err := gce.GetUserAuthorizedKeys() + authorizedKeys, err := gce.GetUserAuthorizedKeys(l) if err != nil { return errors.Wrap(err, "failed to retrieve authorized keys from gcloud") } @@ -818,7 +821,7 @@ func updatePrometheusTargets( go func(index int, v vm.VM) { defer wg.Done() // only gce is supported for prometheus - desc, err := c.DiscoverService(ctx, install.Node(index), "", install.ServiceTypeUI, 0) + desc, err := c.DiscoverService(ctx, l, install.Node(index), "", install.ServiceTypeUI, 0) if err != nil { l.Errorf("error getting the port for node %d: %v", index, err) return @@ -1083,7 +1086,7 @@ func PgURL( var urls []string for i, ip := range ips { - desc, err := c.DiscoverService(ctx, nodes[i], opts.VirtualClusterName, install.ServiceTypeSQL, opts.SQLInstance) + desc, err := c.DiscoverService(ctx, l, nodes[i], opts.VirtualClusterName, install.ServiceTypeSQL, opts.SQLInstance) if err != nil { return nil, err } @@ -1138,7 +1141,7 @@ func urlGenerator( port := uConfig.port if port == 0 { desc, err := c.DiscoverService( - ctx, node, uConfig.virtualClusterName, install.ServiceTypeUI, uConfig.sqlInstance, + ctx, l, node, uConfig.virtualClusterName, install.ServiceTypeUI, uConfig.sqlInstance, ) if err != nil { return nil, err @@ -1157,6 +1160,8 @@ func urlGenerator( if uConfig.openInBrowser { cmd := browserCmd(url) + cli.MaybeLogCmd(context.Background(), l, cmd) + if err := cmd.Run(); err != nil { return nil, err } @@ -1220,7 +1225,7 @@ func SQLPorts( } var ports []int for _, node := range c.Nodes { - port, err := c.NodePort(ctx, node, virtualClusterName, sqlInstance) + port, err := c.NodePort(ctx, l, node, virtualClusterName, sqlInstance) if err != nil { return nil, errors.Wrapf(err, "Error discovering SQL Port for node %d", node) } @@ -1244,7 +1249,7 @@ func AdminPorts( } var ports []int for _, node := range c.Nodes { - port, err := c.NodeUIPort(ctx, node, virtualClusterName, sqlInstance) + port, err := c.NodeUIPort(ctx, l, node, 
virtualClusterName, sqlInstance) if err != nil { return nil, errors.Wrapf(err, "Error discovering UI Port for node %d", node) } @@ -1294,7 +1299,7 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof func(ctx context.Context, node install.Node) (*install.RunResultDetails, error) { res := &install.RunResultDetails{Node: node} host := c.Host(node) - port, err := c.NodeUIPort(ctx, node, "" /* virtualClusterName */, 0 /* sqlInstance */) + port, err := c.NodeUIPort(ctx, l, node, "" /* virtualClusterName */, 0 /* sqlInstance */) if err != nil { return nil, err } @@ -1374,6 +1379,8 @@ func Pprof(ctx context.Context, l *logger.Logger, clusterName string, opts Pprof file) waitCommands = append(waitCommands, cmd) + cli.MaybeLogCmd(ctx, l, cmd) + if err := cmd.Start(); err != nil { return err } @@ -2485,7 +2492,7 @@ func DestroyDNS(ctx context.Context, l *logger.Logger, clusterName string) error return err } return vm.FanOutDNS(c.VMs, func(p vm.DNSProvider, vms vm.List) error { - return p.DeleteRecordsBySubdomain(ctx, c.Name) + return p.DeleteRecordsBySubdomain(ctx, l, c.Name) }) } @@ -2555,7 +2562,7 @@ func sendCaptureCommand( httpClient := httputil.NewClientWithTimeout(0 /* timeout: None */) _, _, err := c.ParallelE(ctx, l, install.WithNodes(nodes).WithDisplay(fmt.Sprintf("Performing workload capture %s", action)), func(ctx context.Context, node install.Node) (*install.RunResultDetails, error) { - port, err := c.NodeUIPort(ctx, node, "" /* virtualClusterName */, 0 /* sqlInstance */) + port, err := c.NodeUIPort(ctx, l, node, "" /* virtualClusterName */, 0 /* sqlInstance */) if err != nil { return nil, err } @@ -2697,7 +2704,7 @@ func CreateLoadBalancer( // Find the SQL ports for the service on all nodes. 
services, err := c.DiscoverServices( - ctx, virtualClusterName, install.ServiceTypeSQL, + ctx, l, virtualClusterName, install.ServiceTypeSQL, install.ServiceNodePredicate(c.TargetNodes()...), install.ServiceInstancePredicate(sqlInstance), ) if err != nil { @@ -2759,7 +2766,7 @@ func LoadBalancerPgURL( return "", err } - services, err := c.DiscoverServices(ctx, opts.VirtualClusterName, install.ServiceTypeSQL, + services, err := c.DiscoverServices(ctx, l, opts.VirtualClusterName, install.ServiceTypeSQL, install.ServiceInstancePredicate(opts.SQLInstance)) if err != nil { return "", err @@ -2786,7 +2793,7 @@ func LoadBalancerIP( if err != nil { return "", err } - services, err := c.DiscoverServices(ctx, virtualClusterName, install.ServiceTypeSQL, + services, err := c.DiscoverServices(ctx, l, virtualClusterName, install.ServiceTypeSQL, install.ServiceInstancePredicate(sqlInstance)) if err != nil { return "", err diff --git a/pkg/roachprod/vm/aws/BUILD.bazel b/pkg/roachprod/vm/aws/BUILD.bazel index 57024382df95..330dbbc0d9e1 100644 --- a/pkg/roachprod/vm/aws/BUILD.bazel +++ b/pkg/roachprod/vm/aws/BUILD.bazel @@ -18,6 +18,7 @@ go_library( "//pkg/roachprod/config", "//pkg/roachprod/logger", "//pkg/roachprod/vm", + "//pkg/roachprod/vm/cli", "//pkg/roachprod/vm/flagstub", "//pkg/util/retry", "//pkg/util/syncutil", diff --git a/pkg/roachprod/vm/aws/aws.go b/pkg/roachprod/vm/aws/aws.go index 342a11378831..b6642642c6a6 100644 --- a/pkg/roachprod/vm/aws/aws.go +++ b/pkg/roachprod/vm/aws/aws.go @@ -7,6 +7,7 @@ package aws import ( + "context" _ "embed" "encoding/json" "fmt" @@ -21,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/config" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/flagstub" "github.com/cockroachdb/cockroach/pkg/util/retry" "github.com/cockroachdb/cockroach/pkg/util/syncutil" 
@@ -35,7 +37,11 @@ import ( const ProviderName = "aws" // providerInstance is the instance to be registered into vm.Providers by Init. -var providerInstance = &Provider{} +var providerInstance = &Provider{ + CLIProvider: cli.CLIProvider{ + CLICommand: "aws", + }, +} //go:embed config.json var configJson []byte @@ -62,6 +68,7 @@ func Init() error { haveRequiredVersion := func() bool { cmd := exec.Command("aws", "--version") + cli.MaybeLogCmd(context.Background(), config.Logger, cmd) output, err := cmd.Output() if err != nil { return false @@ -265,9 +272,6 @@ type ProviderOpts struct { // Provider implements the vm.Provider interface for AWS. type Provider struct { - // Profile to manage cluster in - Profile string - // Path to json for aws configuration, defaults to predefined configuration Config *awsConfig @@ -277,6 +281,8 @@ type Provider struct { // aws accounts to perform action in, used by gcCmd only as it clean ups multiple aws accounts AccountIDs []string + + cli.CLIProvider } func (p *Provider) SupportsSpotVMs() bool { @@ -300,7 +306,7 @@ func (p *Provider) GetPreemptedSpotVMs( } args = append(args, vmList.ProviderIDs()...) var describeInstancesResponse DescribeInstancesOutput - err = p.runJSONCommand(l, args, &describeInstancesResponse) + err = p.RunJSONCommand(context.Background(), l, args, &describeInstancesResponse) if err != nil { // if the describe-instances operation fails with the error InvalidInstanceID.NotFound, // we assume that the instance has been preempted and describe-instances operation is attempted one hour after the instance termination @@ -371,7 +377,7 @@ func (p *Provider) GetVMSpecs( } args = append(args, list.ProviderIDs()...) 
var describeInstancesResponse DescribeInstancesOutput - err := p.runJSONCommand(l, args, &describeInstancesResponse) + err := p.RunJSONCommand(context.Background(), l, args, &describeInstancesResponse) if err != nil { return nil, errors.Wrapf(err, "error describing instances in region %s: ", region) } @@ -613,7 +619,7 @@ func (p *Provider) editLabels( regionArgs = append(regionArgs, list.ProviderIDs()...) g.Go(func() error { - _, err := p.runCommand(l, regionArgs) + _, err := p.RunCommand(context.Background(), l, regionArgs) return err }) } @@ -712,6 +718,9 @@ func (p *Provider) Create( if err := g.Wait(); err != nil { return err } + if config.DryRun { + return nil + } return p.waitForIPs(l, names, regions, providerOpts) } @@ -787,7 +796,7 @@ func (p *Provider) Delete(l *logger.Logger, vms vm.List) error { if len(data.TerminatingInstances) > 0 { _ = data.TerminatingInstances[0].InstanceID // silence unused warning } - return p.runJSONCommand(l, args, &data) + return p.RunJSONCommand(context.Background(), l, args, &data) }) } return g.Wait() @@ -808,7 +817,7 @@ func (p *Provider) Reset(l *logger.Logger, vms vm.List) error { } args = append(args, list.ProviderIDs()...) 
g.Go(func() error { - _, e := p.runCommand(l, args) + _, e := p.RunCommand(context.Background(), l, args) return e }) } @@ -857,7 +866,8 @@ func (p *Provider) iamGetUser(l *logger.Logger) (string, error) { } } args := []string{"iam", "get-user"} - err := p.runJSONCommand(l, args, &userInfo) + err := p.RunJSONCommand(context.Background(), l, args, &userInfo, cli.AlwaysExecute()) + if err != nil { return "", err } @@ -874,7 +884,8 @@ func (p *Provider) stsGetCallerIdentity(l *logger.Logger) (string, error) { Arn string } args := []string{"sts", "get-caller-identity"} - err := p.runJSONCommand(l, args, &userInfo) + err := p.RunJSONCommand(context.Background(), l, args, &userInfo, cli.AlwaysExecute()) + if err != nil { return "", err } @@ -983,8 +994,8 @@ func (p *Provider) getVolumesForInstance( "--region", region, "--filters", "Name=attachment.instance-id,Values=" + instanceID, } + err = p.RunJSONCommand(context.Background(), l, getVolumesArgs, &volumeOut) - err = p.runJSONCommand(l, getVolumesArgs, &volumeOut) if err != nil { return vols, err } @@ -1087,7 +1098,8 @@ func (p *Provider) listRegion( "--region", region, } var describeInstancesResponse DescribeInstancesOutput - err := p.runJSONCommand(l, args, &describeInstancesResponse) + err := p.RunJSONCommand(context.Background(), l, args, &describeInstancesResponse, cli.AlwaysExecute()) + if err != nil { return nil, err } @@ -1261,9 +1273,12 @@ func (p *Provider) runInstance( if err != nil { return errors.Wrapf(err, "could not write AWS startup script to temp file") } - defer func() { - _ = os.Remove(filename) - }() + // Keep startup script in dry-run mode. + if !config.DryRun { + defer func() { + _ = os.Remove(filename) + }() + } withFlagOverride := func(cfg string, fl *string) string { if *fl == "" { @@ -1323,7 +1338,7 @@ func (p *Provider) runInstance( //todo(babusrithar): Add fallback to on-demand instances if spot instances are not available. 
} runInstancesOutput := RunInstancesOutput{} - return p.runJSONCommand(l, args, &runInstancesOutput) + return p.RunJSONCommand(context.Background(), l, args, &runInstancesOutput) } // runSpotInstance uses run-instances command to create a spot instance. @@ -1337,7 +1352,7 @@ func runSpotInstance(l *logger.Logger, p *Provider, args []string, regionName st fmt.Sprintf("MarketType=spot,SpotOptions={SpotInstanceType=one-time,"+ "InstanceInterruptionBehavior=terminate}")) runInstancesOutput := RunInstancesOutput{} - err := p.runJSONCommand(l, spotArgs, &runInstancesOutput) + err := p.RunJSONCommand(context.Background(), l, spotArgs, &runInstancesOutput) if err != nil { return err } @@ -1390,7 +1405,7 @@ func cancelSpotRequest( "--region", regionName, "--spot-instance-request-ids", spotInstanceRequestId, } - err := p.runJSONCommand(l, csrArgs, &CancelSpotInstanceRequestsOutput{}) + err := p.RunJSONCommand(context.Background(), l, csrArgs, &CancelSpotInstanceRequestsOutput{}) if err != nil { // This code path is not expected to be hit, but if it does, we should return the error, so that roachprod // can destroy the cluster being created. 
@@ -1409,7 +1424,7 @@ func describeSpotInstanceRequest( "--spot-instance-request-ids", spotInstanceRequestId, } var describeSpotInstanceRequestsOutput DescribeSpotInstanceRequestsOutput - err := p.runJSONCommand(l, dsirArgs, &describeSpotInstanceRequestsOutput) + err := p.RunJSONCommand(context.Background(), l, dsirArgs, &describeSpotInstanceRequestsOutput) if err != nil { return DescribeSpotInstanceRequestsOutput{}, err } @@ -1451,7 +1466,7 @@ func getSpotInstanceRequestId( "--instance-ids", instanceId, } var describeInstancesResponse DescribeInstancesOutput - err := p.runJSONCommand(l, diArgs, &describeInstancesResponse) + err := p.RunJSONCommand(context.Background(), l, diArgs, &describeInstancesResponse) if err != nil { return "", err } @@ -1547,7 +1562,8 @@ func (p *Provider) AttachVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) ( } var commandResponse attachJsonResponse - err := p.runJSONCommand(l, args, &commandResponse) + err := p.RunJSONCommand(context.Background(), l, args, &commandResponse) + if err != nil { return "", err } @@ -1564,7 +1580,8 @@ func (p *Provider) AttachVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) ( "--block-device-mappings", "DeviceName=" + deviceName + ",Ebs={DeleteOnTermination=true,VolumeId=" + volume.ProviderResourceID + "}", } - _, err = p.runCommand(l, args) + _, err = p.RunCommand(context.Background(), l, args) + if err != nil { return "", err } @@ -1636,7 +1653,7 @@ func (p *Provider) CreateVolume( } args = append(args, "--size", strconv.Itoa(vco.Size)) var volumeDetails createVolume - err = p.runJSONCommand(l, args, &volumeDetails) + err = p.RunJSONCommand(context.Background(), l, args, &volumeDetails) if err != nil { return vol, err } @@ -1659,7 +1676,7 @@ func (p *Provider) CreateVolume( "--query", "Volumes[*].State", } for waitForVolume.Next() { - err = p.runJSONCommand(l, args, &state) + err = p.RunJSONCommand(context.Background(), l, args, &state) if len(state) > 0 && state[0] == "available" { 
close(waitForVolumeCloser) } @@ -1722,9 +1739,8 @@ func (p *Provider) CreateVolumeSnapshot( "--volume-id", volume.ProviderResourceID, "--tag-specifications", "ResourceType=snapshot,Tags=[" + strings.Join(tags, ",") + "]", } - var so snapshotOutput - if err := p.runJSONCommand(l, args, &so); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &so); err != nil { return vm.VolumeSnapshot{}, err } return vm.VolumeSnapshot{ diff --git a/pkg/roachprod/vm/aws/keys.go b/pkg/roachprod/vm/aws/keys.go index 5eb77425741e..b1dd09251ade 100644 --- a/pkg/roachprod/vm/aws/keys.go +++ b/pkg/roachprod/vm/aws/keys.go @@ -6,6 +6,7 @@ package aws import ( + "context" "crypto/sha1" "encoding/base64" "fmt" @@ -28,7 +29,8 @@ func (p *Provider) sshKeyExists(l *logger.Logger, keyName, region string) (bool, "ec2", "describe-key-pairs", "--region", region, } - err := p.runJSONCommand(l, args, &data) + err := p.RunJSONCommand(context.Background(), l, args, &data) + if err != nil { return false, err } @@ -72,7 +74,8 @@ func (p *Provider) sshKeyImport(l *logger.Logger, keyName, region string) error "--public-key-material", fmt.Sprintf("fileb://%s", sshPublicKeyPath), "--tag-specifications", tagSpecs, } - err = p.runJSONCommand(l, args, &data) + err = p.RunJSONCommand(context.Background(), l, args, &data) + // If two roachprod instances run at the same time with the same key, they may // race to upload the key pair. 
if err == nil || strings.Contains(err.Error(), "InvalidKeyPair.Duplicate") { diff --git a/pkg/roachprod/vm/aws/support.go b/pkg/roachprod/vm/aws/support.go index 801b5c260696..4160f2afb188 100644 --- a/pkg/roachprod/vm/aws/support.go +++ b/pkg/roachprod/vm/aws/support.go @@ -6,16 +6,10 @@ package aws import ( - "bytes" - "encoding/json" "os" - "os/exec" - "strings" "text/template" - "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" - "github.com/cockroachdb/errors" ) // Both M5 and I3 machines expose their EBS or local SSD volumes as NVMe block @@ -302,41 +296,6 @@ func writeStartupScript( return tmpfile.Name(), nil } -// runCommand is used to invoke an AWS command. -func (p *Provider) runCommand(l *logger.Logger, args []string) ([]byte, error) { - - if p.Profile != "" { - args = append(args[:len(args):len(args)], "--profile", p.Profile) - } - var stderrBuf bytes.Buffer - cmd := exec.Command("aws", args...) - cmd.Stderr = &stderrBuf - output, err := cmd.Output() - if err != nil { - if exitErr := (*exec.ExitError)(nil); errors.As(err, &exitErr) { - l.Printf("%s", string(exitErr.Stderr)) - } - return nil, errors.Wrapf(err, "failed to run: aws %s: stderr: %v", - strings.Join(args, " "), stderrBuf.String()) - } - return output, nil -} - -// runJSONCommand invokes an aws command and parses the json output. -func (p *Provider) runJSONCommand(l *logger.Logger, args []string, parsed interface{}) error { - // Force json output in case the user has overridden the default behavior. - args = append(args[:len(args):len(args)], "--output", "json") - rawJSON, err := p.runCommand(l, args) - if err != nil { - return err - } - if err := json.Unmarshal(rawJSON, &parsed); err != nil { - return errors.Wrapf(err, "failed to parse json %s", rawJSON) - } - - return nil -} - // regionMap collates VM instances by their region. 
func regionMap(vms vm.List) (map[string]vm.List, error) { // Fan out the work by region diff --git a/pkg/roachprod/vm/cli/BUILD.bazel b/pkg/roachprod/vm/cli/BUILD.bazel new file mode 100644 index 000000000000..ad9976adac3c --- /dev/null +++ b/pkg/roachprod/vm/cli/BUILD.bazel @@ -0,0 +1,25 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "cli", + srcs = ["cli.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli", + visibility = ["//visibility:public"], + deps = [ + "//pkg/roachprod/config", + "//pkg/roachprod/logger", + "@com_github_cockroachdb_errors//:errors", + ], +) + +go_test( + name = "cli_test", + size = "small", + srcs = ["cli_test.go"], + embed = [":cli"], + deps = [ + "//pkg/roachprod/config", + "//pkg/roachprod/logger", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/roachprod/vm/cli/cli.go b/pkg/roachprod/vm/cli/cli.go new file mode 100644 index 000000000000..a3474b34b80b --- /dev/null +++ b/pkg/roachprod/vm/cli/cli.go @@ -0,0 +1,140 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package cli + +import ( + "bytes" + "context" + "encoding/json" + "os/exec" + "regexp" + "strings" + + "github.com/cockroachdb/cockroach/pkg/roachprod/config" + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/errors" +) + +type ( + CLIProvider struct { + // Name of CLI command to execute. + CLICommand string + + // Name of (auth.) profile to use. + Profile string + } + CommandSpec struct { + // If true, the command is _executed_, even in dry-run mode. + // E.g., FindActiveAccount is a prerequisite for Create; thus, it must be executed; otherwise, + // Create will never be invoked. + AlwaysExecute bool + } + + // Option for CommandSpec. 
+ Option func(opts *CommandSpec) +) + +func AlwaysExecute() Option { + return func(spec *CommandSpec) { + spec.AlwaysExecute = true + } +} + +// RunCommand executes a command using CLIProvider's CLICommand. +// If command exits 0, return stdout, ignoring stderr. +// Otherwise, the returned error encapsulates both stdout and stderr, including the exit code. +func (p *CLIProvider) RunCommand( + ctx context.Context, l *logger.Logger, args []string, opts ...Option, +) ([]byte, error) { + cmdSpec := CommandSpec{} + for _, opt := range opts { + opt(&cmdSpec) + } + return p.runCommand(ctx, l, 2, args, cmdSpec) +} + +// RunJSONCommand executes CLIProvider's CLICommand, parsing the output as JSON. +// If the JSON output is valid, it is unmarshaled into 'parsed'; otherwise, an error is returned. +func (p *CLIProvider) RunJSONCommand( + ctx context.Context, l *logger.Logger, args []string, parsed interface{}, opts ...Option, +) error { + if p.CLICommand == "aws" { + // Force json output in case the user has overridden the default behavior. + args = append(args[:len(args):len(args)], "--output", "json") + } + cmdSpec := CommandSpec{} + for _, opt := range opts { + opt(&cmdSpec) + } + rawJSON, err := p.runCommand(ctx, l, 3, args, cmdSpec) + if err != nil { + return err + } + + if !config.DryRun || cmdSpec.AlwaysExecute { + if err := json.Unmarshal(rawJSON, &parsed); err != nil { + return errors.Wrapf(err, "failed to parse json %s", rawJSON) + } + } + + return nil +} + +func (p *CLIProvider) runCommand( + ctx context.Context, l *logger.Logger, depth int, args []string, cmdSpec CommandSpec, +) ([]byte, error) { + if p.CLICommand == "" { + return nil, errors.New("CLICommand is not set") + } + if p.Profile != "" { + args = append(args[:len(args):len(args)], "--profile", p.Profile) + } + maybeLogCmdStr(ctx, l, depth+1, p.CLICommand, args) + // Bail out early if we're in dry-run mode and the command is not required to be executed.
+ if config.DryRun && !cmdSpec.AlwaysExecute { + return nil, nil + } + var stderrBuf bytes.Buffer + var cmd *exec.Cmd + + if ctx == nil { + cmd = exec.Command(p.CLICommand, args...) + } else { + cmd = exec.CommandContext(ctx, p.CLICommand, args...) + } + cmd.Stderr = &stderrBuf + output, err := cmd.Output() + if err != nil { + if exitErr := (*exec.ExitError)(nil); errors.As(err, &exitErr) { + config.Logger.Printf("%s", string(exitErr.Stderr)) + } + stderr := stderrBuf.Bytes() + // TODO(peter,ajwerner): Remove this hack once gcloud behaves when adding new zones. + // 'gcloud compute instances list --project cockroach-ephemeral --format json' + // would fail with, "Invalid value for field 'zone': 'asia-northeast2-a'. Unknown zone.", around the time when + // the region was added but not fully available. It remains unclear whether this can still happen with newly + // added GCP regions. One potential workaround is to constrain the list of zones. + if matched, _ := regexp.Match(`.*Unknown zone`, stderr); !matched { + return nil, errors.Wrapf(err, "failed to run: %s %s\nstdout: %s\nstderr: %s\n", + p.CLICommand, strings.Join(args, " "), bytes.TrimSpace(output), bytes.TrimSpace(stderr)) + } + } + return output, nil +} + +// MaybeLogCmd logs the full command string, in dry-run or verbose mode. +func MaybeLogCmd(ctx context.Context, l *logger.Logger, cmd *exec.Cmd) { + if config.DryRun || config.Verbose { + l.PrintfCtxDepth(ctx, 2, "exec: %s %s", cmd.Args[0], strings.Join(cmd.Args[1:], " ")) + } +} + +// Invoked internally by RunCommand. 
The depth is 3 because we want to ignore 2 internal calls plus PrintfCtxDepth +func maybeLogCmdStr(ctx context.Context, l *logger.Logger, depth int, cmd string, args []string) { + if config.DryRun || config.Verbose { + l.PrintfCtxDepth(ctx, depth+1, "exec: %s %s", cmd, strings.Join(args, " ")) + } +} diff --git a/pkg/roachprod/vm/cli/cli_test.go b/pkg/roachprod/vm/cli/cli_test.go new file mode 100644 index 000000000000..f9357699af92 --- /dev/null +++ b/pkg/roachprod/vm/cli/cli_test.go @@ -0,0 +1,113 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package cli + +import ( + "context" + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/roachprod/config" + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/stretchr/testify/require" +) + +func TestRunCommand(t *testing.T) { + provider := CLIProvider{} + l := logger.NewMockLogger(t) + + setVerbose := func(verbose bool) { + config.Verbose = verbose + config.DryRun = false + } + setDryRun := func(dryRun bool) { + config.DryRun = dryRun + config.Verbose = false + } + // Restore global config. flags after we muck with them below. + verbose := config.Verbose + dryRun := config.DryRun + defer func() { + config.Verbose = verbose + config.DryRun = dryRun + }() + + // Missing CLIProvider.CLICommand + output, err := provider.RunCommand(context.Background(), l.Logger, nil) + require.Error(t, err) + require.Equal(t, "", string(output)) + logger.RequireEqual(t, l, nil) + + // Successful command, empty output. + provider.CLICommand = "true" + l = logger.NewMockLogger(t) + output, err = provider.RunCommand(context.Background(), l.Logger, nil) + require.NoError(t, err) + require.Equal(t, "", string(output)) + logger.RequireEqual(t, l, nil) + + // Successful command, non-empty output. 
+ provider.CLICommand = "echo" + l = logger.NewMockLogger(t) + output, err = provider.RunCommand(context.Background(), l.Logger, []string{"-n", "foo"}) + require.NoError(t, err) + require.Equal(t, "foo", string(output)) + logger.RequireEqual(t, l, nil) + + // Erroneous command, output logged. + provider.CLICommand = "foobar" + l = logger.NewMockLogger(t) + setVerbose(true) + output, err = provider.RunCommand(context.Background(), l.Logger, nil) + require.ErrorContains(t, err, fmt.Sprintf("failed to run: %s \nstdout: \nstderr: \n: exec: \"%s\"", provider.CLICommand, provider.CLICommand)) + require.Equal(t, "", string(output)) + logger.RequireLine(t, l, "exec: foobar") + + // Erroneous command, `--dry-run`. + provider.CLICommand = "foobar" + l = logger.NewMockLogger(t) + setDryRun(true) + output, err = provider.RunCommand(context.Background(), l.Logger, nil) + require.NoError(t, err) + require.Equal(t, "", string(output)) + logger.RequireLine(t, l, "exec: foobar") + + // Successful command, `--verbose`. + provider.CLICommand = "echo" + l = logger.NewMockLogger(t) + setVerbose(true) + output, err = provider.RunCommand(context.Background(), l.Logger, []string{"-n", "foo"}) + require.NoError(t, err) + require.Equal(t, "foo", string(output)) + logger.RequireLine(t, l, "exec: echo -n foo") + + // Successful command, `--dry-run`. + provider.CLICommand = "echo" + l = logger.NewMockLogger(t) + setDryRun(true) + output, err = provider.RunCommand(context.Background(), l.Logger, []string{"-n", "foo"}) + require.NoError(t, err) + require.Equal(t, "", string(output)) + logger.RequireLine(t, l, "exec: echo -n foo") + + // Erroneous command, `--dry-run` with `AlwaysExecute` option. 
+ provider.CLICommand = "foobar" + l = logger.NewMockLogger(t) + setDryRun(true) + output, err = provider.RunCommand(context.Background(), l.Logger, nil, AlwaysExecute()) + require.Error(t, err) + require.Equal(t, "", string(output)) + logger.RequireLine(t, l, "exec: foobar") + + // Successful command, `--dry-run` with `AlwaysExecute` option. + provider.CLICommand = "echo" + l = logger.NewMockLogger(t) + setDryRun(true) + output, err = provider.RunCommand(context.Background(), l.Logger, []string{"-n", "foo"}, AlwaysExecute()) + require.NoError(t, err) + require.Equal(t, "foo", string(output)) + logger.RequireLine(t, l, "exec: echo -n foo") +} diff --git a/pkg/roachprod/vm/dns.go b/pkg/roachprod/vm/dns.go index 9b9c41865804..cf2821a0c53e 100644 --- a/pkg/roachprod/vm/dns.go +++ b/pkg/roachprod/vm/dns.go @@ -12,6 +12,7 @@ import ( "regexp" "strconv" + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/errors" "golang.org/x/sync/errgroup" ) @@ -44,18 +45,18 @@ type DNSRecord struct { // management services. type DNSProvider interface { // CreateRecords creates DNS records. - CreateRecords(ctx context.Context, records ...DNSRecord) error + CreateRecords(ctx context.Context, l *logger.Logger, records ...DNSRecord) error // LookupSRVRecords looks up SRV records for the given service, proto, and // subdomain. The protocol is usually "tcp" and the subdomain is usually the // cluster name. The service is a combination of the virtual cluster name and // type of service. - LookupSRVRecords(ctx context.Context, name string) ([]DNSRecord, error) + LookupSRVRecords(ctx context.Context, l *logger.Logger, name string) ([]DNSRecord, error) // ListRecords lists all DNS records managed for the zone. - ListRecords(ctx context.Context) ([]DNSRecord, error) + ListRecords(ctx context.Context, l *logger.Logger) ([]DNSRecord, error) // DeleteRecordsBySubdomain deletes all DNS records with the given subdomain. 
- DeleteRecordsBySubdomain(ctx context.Context, subdomain string) error + DeleteRecordsBySubdomain(ctx context.Context, l *logger.Logger, subdomain string) error // DeleteRecordsByName deletes all DNS records with the given name. - DeleteRecordsByName(ctx context.Context, names ...string) error + DeleteRecordsByName(ctx context.Context, l *logger.Logger, names ...string) error // Domain returns the domain name (zone) of the DNS provider. Domain() string } diff --git a/pkg/roachprod/vm/gce/BUILD.bazel b/pkg/roachprod/vm/gce/BUILD.bazel index 96b39f2e481d..a6dd4a46d73c 100644 --- a/pkg/roachprod/vm/gce/BUILD.bazel +++ b/pkg/roachprod/vm/gce/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//pkg/roachprod/logger", "//pkg/roachprod/ui", "//pkg/roachprod/vm", + "//pkg/roachprod/vm/cli", "//pkg/roachprod/vm/flagstub", "//pkg/util/randutil", "//pkg/util/retry", diff --git a/pkg/roachprod/vm/gce/dns.go b/pkg/roachprod/vm/gce/dns.go index 02143699c132..2252d9d23472 100644 --- a/pkg/roachprod/vm/gce/dns.go +++ b/pkg/roachprod/vm/gce/dns.go @@ -11,7 +11,6 @@ import ( "fmt" "net" "os" - "os/exec" "sort" "strconv" "strings" @@ -20,6 +19,7 @@ import ( rperrors "github.com/cockroachdb/cockroach/pkg/roachprod/errors" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/errors" @@ -57,12 +57,9 @@ var ( "roachprod-managed.crdb.io", ) ) - -var ErrDNSOperation = fmt.Errorf("error during Google Cloud DNS operation") - var _ vm.DNSProvider = &dnsProvider{} -type ExecFn func(cmd *exec.Cmd) ([]byte, error) +type ExecFn func(context.Context, *logger.Logger, []string, ...cli.Option) ([]byte, error) // dnsProvider implements the vm.DNSProvider interface. 
type dnsProvider struct { @@ -92,9 +89,7 @@ type dnsProvider struct { } func NewDNSProvider() *dnsProvider { - return NewDNSProviderWithExec(func(cmd *exec.Cmd) ([]byte, error) { - return cmd.CombinedOutput() - }) + return NewDNSProviderWithExec(cliProvider.RunCommand) } func NewDNSProviderWithExec(execFn ExecFn) *dnsProvider { @@ -118,7 +113,9 @@ func NewDNSProviderWithExec(execFn ExecFn) *dnsProvider { } // CreateRecords implements the vm.DNSProvider interface. -func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord) error { +func (n *dnsProvider) CreateRecords( + ctx context.Context, l *logger.Logger, records ...vm.DNSRecord, +) error { recordsByName := make(map[string][]vm.DNSRecord) for _, record := range records { // Ensure we use the normalised name for grouping records. @@ -128,7 +125,7 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord for name, recordGroup := range recordsByName { err := n.withRecordLock(name, func() error { - existingRecords, err := n.lookupSRVRecords(ctx, name) + existingRecords, err := n.lookupSRVRecords(ctx, l, name) if err != nil { return err } @@ -158,8 +155,7 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord "--zone", n.managedZone, "--rrdatas", strings.Join(data, ","), } - cmd := exec.CommandContext(ctx, "gcloud", args...) - out, err := n.execFn(cmd) + out, err := n.execFn(ctx, l, args) if err != nil { // Clear the cache entry if the operation failed, as the records may // have been partially updated. @@ -186,7 +182,9 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord } // LookupSRVRecords implements the vm.DNSProvider interface. 
-func (n *dnsProvider) LookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) { +func (n *dnsProvider) LookupSRVRecords( + ctx context.Context, l *logger.Logger, name string, +) ([]vm.DNSRecord, error) { var records []vm.DNSRecord var err error err = n.withRecordLock(name, func() error { @@ -195,27 +193,28 @@ func (n *dnsProvider) LookupSRVRecords(ctx context.Context, name string) ([]vm.D records, err = n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true) return err } - records, err = n.lookupSRVRecords(ctx, name) + records, err = n.lookupSRVRecords(ctx, l, name) return err }) return records, err } // ListRecords implements the vm.DNSProvider interface. -func (n *dnsProvider) ListRecords(ctx context.Context) ([]vm.DNSRecord, error) { - return n.listSRVRecords(ctx, "", dnsMaxResults) +func (n *dnsProvider) ListRecords(ctx context.Context, l *logger.Logger) ([]vm.DNSRecord, error) { + return n.listSRVRecords(ctx, l, "", dnsMaxResults) } // DeleteRecordsByName implements the vm.DNSProvider interface. -func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string) error { +func (n *dnsProvider) DeleteRecordsByName( + ctx context.Context, l *logger.Logger, names ...string, +) error { for _, name := range names { err := n.withRecordLock(name, func() error { args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name, "--type", string(vm.SRV), "--zone", n.managedZone, } - cmd := exec.CommandContext(ctx, "gcloud", args...) - out, err := n.execFn(cmd) + out, err := n.execFn(ctx, l, args) // Clear the cache entry regardless of the outcome. As the records may // have been partially deleted. n.clearCacheEntry(name) @@ -232,9 +231,11 @@ func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string) } // DeleteRecordsBySubdomain implements the vm.DNSProvider interface. 
-func (n *dnsProvider) DeleteRecordsBySubdomain(ctx context.Context, subdomain string) error { +func (n *dnsProvider) DeleteRecordsBySubdomain( + ctx context.Context, l *logger.Logger, subdomain string, +) error { suffix := fmt.Sprintf("%s.%s.", subdomain, n.Domain()) - records, err := n.listSRVRecords(ctx, suffix, dnsMaxResults) + records, err := n.listSRVRecords(ctx, l, suffix, dnsMaxResults) if err != nil { return err } @@ -252,7 +253,7 @@ func (n *dnsProvider) DeleteRecordsBySubdomain(ctx context.Context, subdomain st delete(names, name) } } - return n.DeleteRecordsByName(ctx, maps.Keys(names)...) + return n.DeleteRecordsByName(ctx, l, maps.Keys(names)...) } // Domain implements the vm.DNSProvider interface. @@ -268,13 +269,15 @@ func (n *dnsProvider) Domain() string { // network problems. For lookups, we prefer this to using the gcloud command as // it is faster, and preferable when service information is being queried // regularly. -func (n *dnsProvider) lookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) { +func (n *dnsProvider) lookupSRVRecords( + ctx context.Context, l *logger.Logger, name string, +) ([]vm.DNSRecord, error) { // Check the cache first. if cachedRecords, ok := n.getCache(name); ok { return cachedRecords, nil } // Lookup the records, if no records are found in the cache. - records, err := n.listSRVRecords(ctx, name, dnsMaxResults) + records, err := n.listSRVRecords(ctx, l, name, dnsMaxResults) if err != nil { return nil, err } @@ -295,7 +298,7 @@ func (n *dnsProvider) lookupSRVRecords(ctx context.Context, name string) ([]vm.D // The data field of the records could be a comma-separated list of values if multiple // records are returned for the same name. 
func (n *dnsProvider) listSRVRecords( - ctx context.Context, filter string, limit int, + ctx context.Context, l *logger.Logger, filter string, limit int, ) ([]vm.DNSRecord, error) { args := []string{"--project", n.dnsProject, "dns", "record-sets", "list", "--limit", strconv.Itoa(limit), @@ -306,8 +309,7 @@ func (n *dnsProvider) listSRVRecords( if filter != "" { args = append(args, "--filter", filter) } - cmd := exec.CommandContext(ctx, "gcloud", args...) - res, err := n.execFn(cmd) + res, err := n.execFn(ctx, l, args, cli.AlwaysExecute()) if err != nil { return nil, rperrors.TransientFailure(errors.Wrapf(err, "output: %s", res), dnsProblemLabel) } @@ -405,12 +407,14 @@ func (p *dnsProvider) syncPublicDNS(l *logger.Logger, vms vm.List) (err error) { } defer f.Close() - // Keep imported zone file in dry run mode. - defer func() { - if err := os.Remove(f.Name()); err != nil { - l.Errorf("removing %s failed: %v", f.Name(), err) - } - }() + // Keep imported zone file in dry-run mode. + if !config.DryRun { + defer func() { + if err := os.Remove(f.Name()); err != nil { + l.Errorf("removing %s failed: %v", f.Name(), err) + } + }() + } var zoneBuilder strings.Builder for _, vm := range vms { @@ -426,11 +430,10 @@ func (p *dnsProvider) syncPublicDNS(l *logger.Logger, vms vm.List) (err error) { args := []string{"--project", p.dnsProject, "dns", "record-sets", "import", f.Name(), "-z", p.publicZone, "--delete-all-existing", "--zone-file-format"} - cmd := exec.Command("gcloud", args...) 
- output, err := cmd.CombinedOutput() + _, err = cliProvider.RunCommand(context.Background(), l, args) - return errors.Wrapf(err, - "Command: %s\nOutput: %s\nZone file contents:\n%s", - cmd, output, zoneBuilder.String(), - ) + if err != nil { + return errors.Wrapf(err, "Zone file contents:\n%s", zoneBuilder.String()) + } + return nil } diff --git a/pkg/roachprod/vm/gce/gcloud.go b/pkg/roachprod/vm/gce/gcloud.go index b60d780edf61..d13520aee8c1 100644 --- a/pkg/roachprod/vm/gce/gcloud.go +++ b/pkg/roachprod/vm/gce/gcloud.go @@ -6,9 +6,7 @@ package gce import ( - "bytes" "context" - "encoding/json" "fmt" "math/rand" "os" @@ -25,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/ui" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/flagstub" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" @@ -47,7 +46,11 @@ const ( ) // providerInstance is the instance to be registered into vm.Providers by Init. -var providerInstance = &Provider{} +var providerInstance = &Provider{ + CLIProvider: cli.CLIProvider{ + CLICommand: "gcloud", + }, +} var ( defaultDefaultProject = config.EnvOrDefaultString( @@ -112,31 +115,8 @@ func Init() error { providerInstance.defaultServiceAccount = defaultDefaultServiceAccount initialized = true - vm.Providers[ProviderName] = providerInstance - return nil -} - -func runJSONCommand(args []string, parsed interface{}) error { - cmd := exec.Command("gcloud", args...) - - rawJSON, err := cmd.Output() - if err != nil { - var stderr []byte - if exitErr := (*exec.ExitError)(nil); errors.As(err, &exitErr) { - stderr = exitErr.Stderr - } - // TODO(peter,ajwerner): Remove this hack once gcloud behaves when adding - // new zones. 
- if matched, _ := regexp.Match(`.*Unknown zone`, stderr); !matched { - return errors.Wrapf(err, "failed to run: gcloud %s\nstdout: %s\nstderr: %s\n", - strings.Join(args, " "), bytes.TrimSpace(rawJSON), bytes.TrimSpace(stderr)) - } - } - - if err := json.Unmarshal(rawJSON, &parsed); err != nil { - return errors.Wrapf(err, "failed to parse json %s: %v", rawJSON, rawJSON) - } + vm.Providers[ProviderName] = providerInstance return nil } @@ -373,6 +353,8 @@ type Provider struct { // The service account to use if the default project is in use and no // ServiceAccount was specified. defaultServiceAccount string + + cli.CLIProvider } // LogEntry represents a single log entry from the gcloud logging(stack driver) @@ -403,7 +385,7 @@ func (p *Provider) GetPreemptedSpotVMs( return nil, err } var logEntries []LogEntry - if err := runJSONCommand(args, &logEntries); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &logEntries); err != nil { l.Printf("Error running gcloud cli command: %v\n", err) return nil, err } @@ -431,7 +413,7 @@ func (p *Provider) GetHostErrorVMs( return nil, err } var logEntries []LogEntry - if err := runJSONCommand(args, &logEntries); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &logEntries); err != nil { l.Printf("Error running gcloud cli command: %v\n", err) return nil, err } @@ -460,7 +442,7 @@ func (p *Provider) GetVMSpecs( vmFullResourceName := "projects/" + p.GetProject() + "/zones/" + vmInstance.Zone + "/instances/" + vmInstance.Name args := []string{"compute", "instances", "describe", vmFullResourceName, "--format=json"} - if err := runJSONCommand(args, &vmSpec); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &vmSpec); err != nil { return nil, errors.Wrapf(err, "error describing instance %s in zone %s", vmInstance.Name, vmInstance.Zone) } name, ok := vmSpec["name"].(string) @@ -560,7 +542,7 @@ func (p *Provider) CreateVolumeSnapshot( } var createJsonResponse 
snapshotJson - if err := runJSONCommand(args, &createJsonResponse); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &createJsonResponse); err != nil { return vm.VolumeSnapshot{}, err } @@ -577,8 +559,7 @@ func (p *Provider) CreateVolumeSnapshot( "add-labels", vsco.Name, "--labels", s[:len(s)-1], } - cmd := exec.Command("gcloud", args...) - if _, err := cmd.CombinedOutput(); err != nil { + if _, err := p.RunCommand(context.Background(), l, args); err != nil { return vm.VolumeSnapshot{}, err } return vm.VolumeSnapshot{ @@ -612,7 +593,7 @@ func (p *Provider) ListVolumeSnapshots( } var snapshotsJSONResponse []snapshotJson - if err := runJSONCommand(args, &snapshotsJSONResponse); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &snapshotsJSONResponse); err != nil { return nil, err } @@ -644,8 +625,7 @@ func (p *Provider) DeleteVolumeSnapshots(l *logger.Logger, snapshots ...vm.Volum args = append(args, snapshot.Name) } - cmd := exec.Command("gcloud", args...) - if _, err := cmd.CombinedOutput(); err != nil { + if _, err := p.RunCommand(context.Background(), l, args); err != nil { return err } return nil @@ -714,7 +694,7 @@ func (p *Provider) CreateVolume( } var commandResponse []describeVolumeCommandResponse - err = runJSONCommand(args, &commandResponse) + err = p.RunJSONCommand(context.Background(), l, args, &commandResponse) if err != nil { return vm.Volume{}, err } @@ -742,8 +722,7 @@ func (p *Provider) CreateVolume( "--labels", s[:len(s)-1], "--zone", vco.Zone, } - cmd := exec.Command("gcloud", args...) - if _, err := cmd.CombinedOutput(); err != nil { + if _, err := p.RunCommand(context.Background(), l, args); err != nil { return vm.Volume{}, err } } @@ -769,8 +748,7 @@ func (p *Provider) DeleteVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) e "--disk", volume.ProviderResourceID, "--zone", volume.Zone, } - cmd := exec.Command("gcloud", args...) 
- if _, err := cmd.CombinedOutput(); err != nil { + if _, err := p.RunCommand(context.Background(), l, args); err != nil { return err } } @@ -784,8 +762,7 @@ func (p *Provider) DeleteVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) e "--zone", volume.Zone, "--quiet", } - cmd := exec.Command("gcloud", args...) - if _, err := cmd.CombinedOutput(); err != nil { + if _, err := p.RunCommand(context.Background(), l, args); err != nil { return err } } @@ -813,7 +790,7 @@ func (p *Provider) ListVolumes(l *logger.Logger, v *vm.VM) ([]vm.Volume, error) "--zone", v.Zone, "--format", "json(disks)", } - if err := runJSONCommand(args, &commandResponse); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &commandResponse); err != nil { return nil, err } attachedDisks = commandResponse.Disks @@ -834,7 +811,7 @@ func (p *Provider) ListVolumes(l *logger.Logger, v *vm.VM) ([]vm.Volume, error) "--filter", fmt.Sprintf("users:(%s)", v.Name), "--format", "json", } - if err := runJSONCommand(args, &describedVolumes); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &describedVolumes); err != nil { return nil, err } } @@ -900,7 +877,7 @@ func (p *Provider) AttachVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) ( } var commandResponse []instanceDisksResponse - if err := runJSONCommand(args, &commandResponse); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &commandResponse); err != nil { return "", err } found := false @@ -928,7 +905,7 @@ func (p *Provider) AttachVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) ( "--format=json(disks)", } - if err := runJSONCommand(args, &commandResponse); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &commandResponse); err != nil { return "", err } @@ -1157,11 +1134,9 @@ func (o *ProviderOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { func (p *Provider) CleanSSH(l *logger.Logger) error { for _, prj := range p.GetProjects() { args := 
[]string{"compute", "config-ssh", "--project", prj, "--quiet", "--remove"} - cmd := exec.Command("gcloud", args...) - - output, err := cmd.CombinedOutput() + _, err := p.RunCommand(context.Background(), l, args) if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + return err } } return nil @@ -1206,9 +1181,8 @@ func (p *Provider) editLabels( vmArgs = append(vmArgs, commonArgs...) g.Go(func() error { - cmd := exec.Command("gcloud", vmArgs...) - if b, err := cmd.CombinedOutput(); err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", vmArgs, string(b)) + if _, err := p.RunCommand(context.Background(), l, vmArgs); err != nil { + return err } return nil }) @@ -1460,10 +1434,8 @@ func createInstanceTemplates( createTemplateArgs = append(createTemplateArgs, "--labels", labelsArg) createTemplateArgs = append(createTemplateArgs, templateName) g.Go(func() error { - cmd := exec.Command("gcloud", createTemplateArgs...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", createTemplateArgs, output) + if _, err := cliProvider.RunCommand(context.Background(), l, createTemplateArgs); err != nil { + return err } return nil }) @@ -1511,10 +1483,8 @@ func createInstanceGroups( argsWithZone = append(argsWithZone, "--zone", zone) argsWithZone = append(argsWithZone, "--template", templateName) g.Go(func() error { - cmd := exec.Command("gcloud", argsWithZone...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", argsWithZone, output) + if _, err := cliProvider.RunCommand(context.Background(), l, argsWithZone); err != nil { + return err } return nil }) @@ -1532,10 +1502,8 @@ func waitForGroupStability(l *logger.Logger, project, groupName string, zones [] "--project", project, groupName} g.Go(func() error { - cmd := exec.Command("gcloud", groupStableArgs...) 
- output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", groupStableArgs, output) + if _, err := cliProvider.RunCommand(context.Background(), l, groupStableArgs); err != nil { + return err } return nil }) @@ -1563,14 +1531,17 @@ func (p *Provider) Create( "`roachprod gc --gce-project=%s` cronjob", project) } if providerOpts.Managed { - if err := checkSDKVersion("450.0.0" /* minVersion */, "required by managed instance groups"); err != nil { + if err := checkSDKVersion(l, "450.0.0" /* minVersion */, "required by managed instance groups"); err != nil { return err } } instanceArgs, cleanUpFn, err := p.computeInstanceArgs(l, opts, providerOpts) if cleanUpFn != nil { - defer cleanUpFn() + // Keep temp. files in dry run mode. + if !config.DryRun { + defer cleanUpFn() + } } if err != nil { return err @@ -1644,10 +1615,8 @@ func (p *Provider) Create( for _, host := range zoneHosts { argsWithHost := append(argsWithZone[:len(argsWithZone):len(argsWithZone)], []string{"--instance", host}...) g.Go(func() error { - cmd := exec.Command("gcloud", argsWithHost...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", argsWithHost, output) + if _, err = p.RunCommand(context.Background(), l, argsWithHost); err != nil { + return err } return nil }) @@ -1672,10 +1641,8 @@ func (p *Provider) Create( argsWithZone := append(createArgs[:len(createArgs):len(createArgs)], "--zone", zone) argsWithZone = append(argsWithZone, zoneHosts...) g.Go(func() error { - cmd := exec.Command("gcloud", argsWithZone...) 
- output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", argsWithZone, output) + if _, err := p.RunCommand(context.Background(), l, argsWithZone); err != nil { + return err } return nil }) @@ -1686,7 +1653,7 @@ func (p *Provider) Create( return err } } - return propagateDiskLabels(l, project, labels, zoneToHostNames, opts.SSDOpts.UseLocalSSD, providerOpts.PDVolumeCount) + return p.propagateDiskLabels(l, project, labels, zoneToHostNames, opts.SSDOpts.UseLocalSSD, providerOpts.PDVolumeCount) } // computeGrowDistribution computes the distribution of new nodes across the @@ -1732,12 +1699,9 @@ func (p *Provider) Shrink(l *logger.Logger, vmsToDelete vm.List, clusterName str args := []string{"compute", "instance-groups", "managed", "delete-instances", groupName, "--project", project, "--zone", zone, "--instances", strings.Join(instances, ",")} g.Go(func() error { - cmd := exec.Command("gcloud", args...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) - } - return nil + _, err := p.RunCommand(context.Background(), l, args) + + return err }) } @@ -1756,7 +1720,7 @@ func (p *Provider) Grow(l *logger.Logger, vms vm.List, clusterName string, names project := vms[0].Project groupName := instanceGroupName(clusterName) - groups, err := listManagedInstanceGroups(project, groupName) + groups, err := listManagedInstanceGroups(l, project, groupName) if err != nil { return err } @@ -1782,10 +1746,8 @@ func (p *Provider) Grow(l *logger.Logger, vms vm.List, clusterName string, names argsWithName := append(createArgs[:len(createArgs):len(createArgs)], []string{"--instance", name}...) zoneToHostNames[group.Zone] = append(zoneToHostNames[group.Zone], name) g.Go(func() error { - cmd := exec.Command("gcloud", argsWithName...) 
- output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", argsWithName, output) + if _, err := p.RunCommand(context.Background(), l, argsWithName); err != nil { + return err } return nil }) @@ -1809,7 +1771,7 @@ func (p *Provider) Grow(l *logger.Logger, vms vm.List, clusterName string, names } labelsJoined += fmt.Sprintf("%s=%s", key, value) } - return propagateDiskLabels(l, project, labelsJoined, zoneToHostNames, len(vms[0].LocalDisks) != 0, + return p.propagateDiskLabels(l, project, labelsJoined, zoneToHostNames, len(vms[0].LocalDisks) != 0, len(vms[0].NonBootAttachedVolumes)) } @@ -1822,10 +1784,10 @@ type jsonBackendService struct { SelfLink string `json:"selfLink"` } -func listBackendServices(project string) ([]jsonBackendService, error) { +func listBackendServices(l *logger.Logger, project string) ([]jsonBackendService, error) { args := []string{"compute", "backend-services", "list", "--project", project, "--format", "json"} var backends []jsonBackendService - if err := runJSONCommand(args, &backends); err != nil { + if err := cliProvider.RunJSONCommand(context.Background(), l, args, &backends); err != nil { return nil, err } return backends, nil @@ -1838,10 +1800,10 @@ type jsonForwardingRule struct { Target string `json:"target"` } -func listForwardingRules(project string) ([]jsonForwardingRule, error) { +func listForwardingRules(l *logger.Logger, project string) ([]jsonForwardingRule, error) { args := []string{"compute", "forwarding-rules", "list", "--project", project, "--format", "json"} var rules []jsonForwardingRule - if err := runJSONCommand(args, &rules); err != nil { + if err := cliProvider.RunJSONCommand(context.Background(), l, args, &rules); err != nil { return nil, err } return rules, nil @@ -1853,10 +1815,10 @@ type jsonTargetTCPProxy struct { Service string `json:"service"` } -func listTargetTCPProxies(project string) ([]jsonTargetTCPProxy, error) { +func listTargetTCPProxies(l 
*logger.Logger, project string) ([]jsonTargetTCPProxy, error) { args := []string{"compute", "target-tcp-proxies", "list", "--project", project, "--format", "json"} var proxies []jsonTargetTCPProxy - if err := runJSONCommand(args, &proxies); err != nil { + if err := cliProvider.RunJSONCommand(context.Background(), l, args, &proxies); err != nil { return nil, err } return proxies, nil @@ -1867,10 +1829,10 @@ type jsonHealthCheck struct { SelfLink string `json:"selfLink"` } -func listHealthChecks(project string) ([]jsonHealthCheck, error) { +func listHealthChecks(l *logger.Logger, project string) ([]jsonHealthCheck, error) { args := []string{"compute", "health-checks", "list", "--project", project, "--format", "json"} var checks []jsonHealthCheck - if err := runJSONCommand(args, &checks); err != nil { + if err := cliProvider.RunJSONCommand(context.Background(), l, args, &checks); err != nil { return nil, err } return checks, nil @@ -1882,7 +1844,7 @@ func listHealthChecks(project string) ([]jsonHealthCheck, error) { // function does not return an error if the resources do not exist. Multiple // load balancers can be associated with a single cluster, so we need to delete // all of them. Health checks associated with the cluster are also deleted. -func deleteLoadBalancerResources(project, clusterName, portFilter string) error { +func deleteLoadBalancerResources(l *logger.Logger, project, clusterName, portFilter string) error { // List all the components of the load balancer resources tied to the project. 
var g errgroup.Group var services []jsonBackendService @@ -1890,19 +1852,19 @@ func deleteLoadBalancerResources(project, clusterName, portFilter string) error var rules []jsonForwardingRule var healthChecks []jsonHealthCheck g.Go(func() (err error) { - services, err = listBackendServices(project) + services, err = listBackendServices(l, project) return }) g.Go(func() (err error) { - proxies, err = listTargetTCPProxies(project) + proxies, err = listTargetTCPProxies(l, project) return }) g.Go(func() (err error) { - rules, err = listForwardingRules(project) + rules, err = listForwardingRules(l, project) return }) g.Go(func() (err error) { - healthChecks, err = listHealthChecks(project) + healthChecks, err = listHealthChecks(l, project) return }) if err := g.Wait(); err != nil { @@ -1962,10 +1924,8 @@ func deleteLoadBalancerResources(project, clusterName, portFilter string) error "--project", project, } g.Go(func() error { - cmd := exec.Command("gcloud", args...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + if _, err := cliProvider.RunCommand(context.Background(), l, args); err != nil { + return err } return nil }) @@ -1981,10 +1941,8 @@ func deleteLoadBalancerResources(project, clusterName, portFilter string) error "--project", project, } g.Go(func() error { - cmd := exec.Command("gcloud", args...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + if _, err := cliProvider.RunCommand(context.Background(), l, args); err != nil { + return err } return nil }) @@ -2001,10 +1959,8 @@ func deleteLoadBalancerResources(project, clusterName, portFilter string) error "--project", project, } g.Go(func() error { - cmd := exec.Command("gcloud", args...) 
- output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + if _, err := cliProvider.RunCommand(context.Background(), l, args); err != nil { + return err } return nil }) @@ -2020,10 +1976,8 @@ func deleteLoadBalancerResources(project, clusterName, portFilter string) error "--project", project, } g.Go(func() error { - cmd := exec.Command("gcloud", args...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + if _, err := cliProvider.RunCommand(context.Background(), l, args); err != nil { + return err } return nil }) @@ -2032,12 +1986,12 @@ func deleteLoadBalancerResources(project, clusterName, portFilter string) error } // DeleteLoadBalancer implements the vm.Provider interface. -func (p *Provider) DeleteLoadBalancer(_ *logger.Logger, vms vm.List, port int) error { +func (p *Provider) DeleteLoadBalancer(l *logger.Logger, vms vm.List, port int) error { clusterName, err := vms[0].ClusterName() if err != nil { return err } - return deleteLoadBalancerResources(vms[0].Project, clusterName, strconv.Itoa(port)) + return deleteLoadBalancerResources(l, vms[0].Project, clusterName, strconv.Itoa(port)) } // loadBalancerNameParts returns the cluster name, resource type, and port of a @@ -2068,7 +2022,7 @@ func loadBalancerResourceName(clusterName string, port int, resourceType string) // used to support global load balancing. The different parts of the load // balancer are created sequentially, as they depend on each other. 
func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) error { - if err := checkSDKVersion("450.0.0" /* minVersion */, "required by load balancers"); err != nil { + if err := checkSDKVersion(l, "450.0.0" /* minVersion */, "required by load balancers"); err != nil { return err } if !isManaged(vms) { @@ -2079,7 +2033,7 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e if err != nil { return err } - groups, err := listManagedInstanceGroups(project, instanceGroupName(clusterName)) + groups, err := listManagedInstanceGroups(l, project, instanceGroupName(clusterName)) if err != nil { return err } @@ -2089,18 +2043,17 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e var args []string healthCheckName := loadBalancerResourceName(clusterName, port, "health-check") - output, err := func() ([]byte, error) { + _, err = func() ([]byte, error) { defer ui.NewDefaultSpinner(l, "create health check").Start()() args = []string{"compute", "health-checks", "create", "tcp", healthCheckName, "--project", project, "--port", strconv.Itoa(port), } - cmd := exec.Command("gcloud", args...) - return cmd.CombinedOutput() + return p.RunCommand(context.Background(), l, args) }() if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + return err } loadBalancerName := loadBalancerResourceName(clusterName, port, "load-balancer") @@ -2114,19 +2067,18 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e "--timeout", "5m", "--port-name", "cockroach", } - output, err = func() ([]byte, error) { + _, err = func() ([]byte, error) { defer ui.NewDefaultSpinner(l, "creating load balancer backend").Start()() - cmd := exec.Command("gcloud", args...) 
- return cmd.CombinedOutput() + return p.RunCommand(context.Background(), l, args) }() if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + return err } // Add the instance group to the backend service. This has to be done // sequentially, and for each zone, because gcloud does not allow adding // multiple instance groups in parallel. - output, err = func() ([]byte, error) { + _, err = func() ([]byte, error) { spinner := ui.NewDefaultCountingSpinner(l, "adding backends to load balancer", len(groups)) defer spinner.Start()() for n, group := range groups { @@ -2138,8 +2090,7 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e "--balancing-mode", "UTILIZATION", "--max-utilization", "0.8", } - cmd := exec.Command("gcloud", args...) - output, err = cmd.CombinedOutput() + output, err := p.RunCommand(context.Background(), l, args) if err != nil { return output, err } @@ -2148,25 +2099,24 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e return nil, nil }() if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + return err } proxyName := loadBalancerResourceName(clusterName, port, "proxy") - output, err = func() ([]byte, error) { + _, err = func() ([]byte, error) { defer ui.NewDefaultSpinner(l, "creating load balancer proxy").Start()() args = []string{"compute", "target-tcp-proxies", "create", proxyName, "--project", project, "--backend-service", loadBalancerName, "--proxy-header", "NONE", } - cmd := exec.Command("gcloud", args...) 
- return cmd.CombinedOutput() + return p.RunCommand(context.Background(), l, args) }() if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + return err } - output, err = func() ([]byte, error) { + _, err = func() ([]byte, error) { defer ui.NewDefaultSpinner(l, "creating load balancer forwarding rule").Start()() args = []string{"compute", "forwarding-rules", "create", loadBalancerResourceName(clusterName, port, "forwarding-rule"), @@ -2175,11 +2125,10 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e "--target-tcp-proxy", proxyName, "--ports", strconv.Itoa(port), } - cmd := exec.Command("gcloud", args...) - return cmd.CombinedOutput() + return p.RunCommand(context.Background(), l, args) }() if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + return err } // Named ports can be set in parallel for all instance groups. @@ -2191,10 +2140,8 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e "--named-ports", "cockroach:" + strconv.Itoa(port), } g.Go(func() error { - cmd := exec.Command("gcloud", groupArgs...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", groupArgs, output) + if _, err := p.RunCommand(context.Background(), l, groupArgs); err != nil { + return err } return nil }) @@ -2205,7 +2152,7 @@ func (p *Provider) CreateLoadBalancer(l *logger.Logger, vms vm.List, port int) e // ListLoadBalancers returns the list of load balancers associated with the // given VMs. The VMs have to be part of a managed instance group. The load // balancers are returned as a list of service addresses. -func (p *Provider) ListLoadBalancers(_ *logger.Logger, vms vm.List) ([]vm.ServiceAddress, error) { +func (p *Provider) ListLoadBalancers(l *logger.Logger, vms vm.List) ([]vm.ServiceAddress, error) { // Only managed instance groups support load balancers. 
if !isManaged(vms) { return nil, nil @@ -2215,7 +2162,7 @@ func (p *Provider) ListLoadBalancers(_ *logger.Logger, vms vm.List) ([]vm.Servic if err != nil { return nil, err } - rules, err := listForwardingRules(project) + rules, err := listForwardingRules(l, project) if err != nil { return nil, err } @@ -2287,7 +2234,7 @@ func AllowedLocalSSDCount(machineType string) ([]int, error) { // N.B. neither boot disk nor additional persistent disks are assigned VM labels by default. // Hence, we must propagate them. See: https://cloud.google.com/compute/docs/labeling-resources#labeling_boot_disks -func propagateDiskLabels( +func (p *Provider) propagateDiskLabels( l *logger.Logger, project string, labels string, @@ -2313,13 +2260,9 @@ func propagateDiskLabels( bootDiskArgs = append(bootDiskArgs, zoneArg...) // N.B. boot disk has the same name as the host. bootDiskArgs = append(bootDiskArgs, hostName) - cmd := exec.Command("gcloud", bootDiskArgs...) + _, err := p.RunCommand(context.Background(), l, bootDiskArgs) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", bootDiskArgs, output) - } - return nil + return err }) if !useLocalSSD { @@ -2334,11 +2277,8 @@ func propagateDiskLabels( persistentDiskArgs = append(persistentDiskArgs, zoneArg...) // N.B. additional persistent disks are suffixed with the offset, starting at 1. persistentDiskArgs = append(persistentDiskArgs, fmt.Sprintf("%s-%d", hostName, offset)) - cmd := exec.Command("gcloud", persistentDiskArgs...) - - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", persistentDiskArgs, output) + if _, err := p.RunCommand(context.Background(), l, persistentDiskArgs); err != nil { + return err } } return nil @@ -2358,10 +2298,10 @@ type jsonInstanceTemplate struct { // listInstanceTemplates returns a list of instance templates for a given // project. 
-func listInstanceTemplates(project string) ([]jsonInstanceTemplate, error) { +func listInstanceTemplates(l *logger.Logger, project string) ([]jsonInstanceTemplate, error) { args := []string{"compute", "instance-templates", "list", "--project", project, "--format", "json"} var templates []jsonInstanceTemplate - if err := runJSONCommand(args, &templates); err != nil { + if err := cliProvider.RunJSONCommand(context.Background(), l, args, &templates, cli.AlwaysExecute()); err != nil { return nil, err } return templates, nil @@ -2376,23 +2316,23 @@ type jsonManagedInstanceGroup struct { // listManagedInstanceGroups returns a list of managed instance groups for a // given group name. Groups may exist in multiple zones with the same name. This // function returns a list of all groups with the given name. -func listManagedInstanceGroups(project, groupName string) ([]jsonManagedInstanceGroup, error) { +func listManagedInstanceGroups( + l *logger.Logger, project, groupName string, +) ([]jsonManagedInstanceGroup, error) { args := []string{"compute", "instance-groups", "list", "--only-managed", "--project", project, "--format", "json", "--filter", fmt.Sprintf("name=%s", groupName)} var groups []jsonManagedInstanceGroup - if err := runJSONCommand(args, &groups); err != nil { + if err := cliProvider.RunJSONCommand(context.Background(), l, args, &groups, cli.AlwaysExecute()); err != nil { return nil, err } return groups, nil } // deleteInstanceTemplate deletes the instance template for the cluster. -func deleteInstanceTemplate(project, templateName string) error { +func deleteInstanceTemplate(l *logger.Logger, project, templateName string) error { args := []string{"compute", "instance-templates", "delete", "--project", project, "--quiet", templateName} - cmd := exec.Command("gcloud", args...) 
- output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) + if _, err := cliProvider.RunCommand(context.Background(), l, args); err != nil { + return err } return nil } @@ -2432,12 +2372,12 @@ func (p *Provider) deleteManaged(l *logger.Logger, vms vm.List) error { // Delete any load balancer resources associated with the cluster. Trying to // delete the instance group before the load balancer resources will result // in an error. - err := deleteLoadBalancerResources(project, cluster, "" /* portFilter */) + err := deleteLoadBalancerResources(l, project, cluster, "" /* portFilter */) if err != nil { return err } // Multiple instance groups can exist for a single cluster, one for each zone. - projectGroups, err := listManagedInstanceGroups(project, instanceGroupName(cluster)) + projectGroups, err := listManagedInstanceGroups(l, project, instanceGroupName(cluster)) if err != nil { return err } @@ -2447,10 +2387,8 @@ func (p *Provider) deleteManaged(l *logger.Logger, vms vm.List) error { "--zone", group.Zone, group.Name} g.Go(func() error { - cmd := exec.Command("gcloud", argsWithZone...) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", argsWithZone, output) + if _, err := p.RunCommand(context.Background(), l, argsWithZone); err != nil { + return err } return nil }) @@ -2465,7 +2403,7 @@ func (p *Provider) deleteManaged(l *logger.Logger, vms vm.List) error { // deleted. 
g = errgroup.Group{} for cluster, project := range clusterProjectMap { - templates, err := listInstanceTemplates(project) + templates, err := listInstanceTemplates(l, project) if err != nil { return err } @@ -2475,7 +2413,7 @@ func (p *Provider) deleteManaged(l *logger.Logger, vms vm.List) error { continue } g.Go(func() error { - return deleteInstanceTemplate(project, template.Name) + return deleteInstanceTemplate(l, project, template.Name) }) } } @@ -2509,13 +2447,9 @@ func (p *Provider) deleteUnmanaged(l *logger.Logger, vms vm.List) error { args = append(args, names...) g.Go(func() error { - cmd := exec.CommandContext(ctx, "gcloud", args...) + _, err := p.RunCommand(ctx, l, args) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) - } - return nil + return err }) } } @@ -2552,13 +2486,9 @@ func (p *Provider) Reset(l *logger.Logger, vms vm.List) error { args = append(args, names...) g.Go(func() error { - cmd := exec.CommandContext(ctx, "gcloud", args...) 
+ _, err := p.RunCommand(ctx, l, args) - output, err := cmd.CombinedOutput() - if err != nil { - return errors.Wrapf(err, "Command: gcloud %s\nOutput: %s", args, output) - } - return nil + return err }) } } @@ -2578,7 +2508,7 @@ func (p *Provider) FindActiveAccount(l *logger.Logger) (string, error) { args := []string{"auth", "list", "--format", "json", "--filter", "status~ACTIVE"} accounts := make([]jsonAuth, 0) - if err := runJSONCommand(args, &accounts); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &accounts, cli.AlwaysExecute()); err != nil { return "", err } @@ -2609,7 +2539,7 @@ func (p *Provider) List(l *logger.Logger, opts vm.ListOptions) (vm.List, error) // Run the command, extracting the JSON payload jsonVMS := make([]jsonVM, 0) - if err := runJSONCommand(args, &jsonVMS); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &jsonVMS, cli.AlwaysExecute()); err != nil { return nil, err } @@ -2630,7 +2560,7 @@ func (p *Provider) List(l *logger.Logger, opts vm.ListOptions) (vm.List, error) } var disks []describeVolumeCommandResponse - if err := runJSONCommand(args, &disks); err != nil { + if err := p.RunJSONCommand(context.Background(), l, args, &disks, cli.AlwaysExecute()); err != nil { return nil, err } @@ -2680,7 +2610,7 @@ func (p *Provider) List(l *logger.Logger, opts vm.ListOptions) (vm.List, error) if projTemplatesInUse == nil { projTemplatesInUse = make(map[string]struct{}) } - templates, err := listInstanceTemplates(prj) + templates, err := listInstanceTemplates(l, prj) if err != nil { return nil, err } @@ -2941,11 +2871,11 @@ func lastComponent(url string) string { // checkSDKVersion checks that the gcloud SDK version is at least minVersion. // If it is not, it returns an error with the given message. 
-func checkSDKVersion(minVersion, message string) error { +func checkSDKVersion(l *logger.Logger, minVersion, message string) error { var jsonVersion struct { GoogleCloudSDK string `json:"Google Cloud SDK"` } - err := runJSONCommand([]string{"version", "--format", "json"}, &jsonVersion) + err := cliProvider.RunJSONCommand(context.Background(), l, []string{"version", "--format", "json"}, &jsonVersion) if err != nil { return err } diff --git a/pkg/roachprod/vm/gce/testutils/BUILD.bazel b/pkg/roachprod/vm/gce/testutils/BUILD.bazel index 7787fdf79ef1..376784f11c45 100644 --- a/pkg/roachprod/vm/gce/testutils/BUILD.bazel +++ b/pkg/roachprod/vm/gce/testutils/BUILD.bazel @@ -6,7 +6,9 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce/testutils", visibility = ["//visibility:public"], deps = [ + "//pkg/roachprod/logger", "//pkg/roachprod/vm", + "//pkg/roachprod/vm/cli", "//pkg/roachprod/vm/gce", "//pkg/util/syncutil", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/roachprod/vm/gce/testutils/dns_server.go b/pkg/roachprod/vm/gce/testutils/dns_server.go index 9b655a938a81..34e4691ad43f 100644 --- a/pkg/roachprod/vm/gce/testutils/dns_server.go +++ b/pkg/roachprod/vm/gce/testutils/dns_server.go @@ -6,13 +6,15 @@ package testutils import ( + "context" "encoding/json" "fmt" "math/rand" - "os/exec" "strings" + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/errors" @@ -137,7 +139,9 @@ func (t *testDNSServer) Metrics() Metrics { return t.metrics } -func (t *testDNSServer) execFunc(cmd *exec.Cmd) ([]byte, error) { +func (t *testDNSServer) execFunc( + ctx context.Context, l *logger.Logger, args []string, opts ...cli.Option, +) ([]byte, error) { getArg := func(args []string, arg string) string { 
for i, a := range args { if a == arg { @@ -146,16 +150,16 @@ func (t *testDNSServer) execFunc(cmd *exec.Cmd) ([]byte, error) { } return "" } - for _, arg := range cmd.Args { + for _, arg := range args { switch arg { case "list": - return t.list(getArg(cmd.Args, "--filter")) + return t.list(getArg(args, "--filter")) case "create": - return t.create(getArg(cmd.Args, "create"), getArg(cmd.Args, "--rrdatas")) + return t.create(getArg(args, "create"), getArg(args, "--rrdatas")) case "update": - return t.update(getArg(cmd.Args, "update"), getArg(cmd.Args, "--rrdatas")) + return t.update(getArg(args, "update"), getArg(args, "--rrdatas")) case "delete": - return t.delete(getArg(cmd.Args, "delete")) + return t.delete(getArg(args, "delete")) } } return nil, errors.New("unknown command") diff --git a/pkg/roachprod/vm/gce/utils.go b/pkg/roachprod/vm/gce/utils.go index bca90f3b87a4..b2ec35e1e362 100644 --- a/pkg/roachprod/vm/gce/utils.go +++ b/pkg/roachprod/vm/gce/utils.go @@ -8,6 +8,7 @@ package gce import ( "bufio" "bytes" + "context" "fmt" "os" "os/exec" @@ -19,10 +20,15 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/config" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm/cli" "github.com/cockroachdb/errors" "golang.org/x/crypto/ssh" ) +var cliProvider = cli.CLIProvider{ + CLICommand: "gcloud", +} + const gceDiskStartupScriptTemplate = `#!/usr/bin/env bash # Script for setting up a GCE machine for roachprod use. @@ -420,7 +426,7 @@ func (ak AuthorizedKeys) AsProjectMetadata() []byte { // GetUserAuthorizedKeys retrieves reads a list of user public keys from the // gcloud cockroach-ephemeral project and returns them formatted for use in // an authorized_keys file. 
-func GetUserAuthorizedKeys() (AuthorizedKeys, error) { +func GetUserAuthorizedKeys(l *logger.Logger) (AuthorizedKeys, error) { var outBuf bytes.Buffer // The below command will return a stream of user:pubkey as text. cmd := exec.Command("gcloud", "compute", "project-info", "describe", @@ -429,6 +435,8 @@ func GetUserAuthorizedKeys() (AuthorizedKeys, error) { cmd.Stderr = os.Stderr cmd.Stdout = &outBuf + cli.MaybeLogCmd(context.Background(), l, cmd) + if err := cmd.Run(); err != nil { return nil, err } @@ -475,8 +483,8 @@ func GetUserAuthorizedKeys() (AuthorizedKeys, error) { // keys installed on clusters managed by roachprod. Currently, these // keys are stored in the project metadata for the roachprod's // `DefaultProject`. -func AddUserAuthorizedKey(ak AuthorizedKey) error { - existingKeys, err := GetUserAuthorizedKeys() +func AddUserAuthorizedKey(l *logger.Logger, ak AuthorizedKey) error { + existingKeys, err := GetUserAuthorizedKeys(l) if err != nil { return err } diff --git a/pkg/roachprod/vm/local/dns.go b/pkg/roachprod/vm/local/dns.go index 01ce8378a3c9..1f7842212285 100644 --- a/pkg/roachprod/vm/local/dns.go +++ b/pkg/roachprod/vm/local/dns.go @@ -15,6 +15,7 @@ import ( "regexp" "github.com/cockroachdb/cockroach/pkg/roachprod/lock" + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/errors" "github.com/cockroachdb/errors/oserror" @@ -40,7 +41,9 @@ func (n *dnsProvider) Domain() string { } // CreateRecords is part of the vm.DNSProvider interface. 
-func (n *dnsProvider) CreateRecords(_ context.Context, records ...vm.DNSRecord) error { +func (n *dnsProvider) CreateRecords( + _ context.Context, _ *logger.Logger, records ...vm.DNSRecord, +) error { unlock, err := lock.AcquireFilesystemLock(n.lockFilePath) if err != nil { return err @@ -59,7 +62,9 @@ func (n *dnsProvider) CreateRecords(_ context.Context, records ...vm.DNSRecord) } // LookupSRVRecords is part of the vm.DNSProvider interface. -func (n *dnsProvider) LookupSRVRecords(_ context.Context, name string) ([]vm.DNSRecord, error) { +func (n *dnsProvider) LookupSRVRecords( + _ context.Context, _ *logger.Logger, name string, +) ([]vm.DNSRecord, error) { records, err := n.loadRecords() if err != nil { return nil, err @@ -74,7 +79,7 @@ func (n *dnsProvider) LookupSRVRecords(_ context.Context, name string) ([]vm.DNS } // ListRecords is part of the vm.DNSProvider interface. -func (n *dnsProvider) ListRecords(_ context.Context) ([]vm.DNSRecord, error) { +func (n *dnsProvider) ListRecords(_ context.Context, _ *logger.Logger) ([]vm.DNSRecord, error) { records, err := n.loadRecords() if err != nil { return nil, err @@ -83,7 +88,9 @@ func (n *dnsProvider) ListRecords(_ context.Context) ([]vm.DNSRecord, error) { } // DeleteRecordsByName is part of the vm.DNSProvider interface. -func (n *dnsProvider) DeleteRecordsByName(_ context.Context, names ...string) error { +func (n *dnsProvider) DeleteRecordsByName( + _ context.Context, _ *logger.Logger, names ...string, +) error { unlock, err := lock.AcquireFilesystemLock(n.lockFilePath) if err != nil { return err @@ -101,7 +108,9 @@ func (n *dnsProvider) DeleteRecordsByName(_ context.Context, names ...string) er } // DeleteRecordsBySubdomain is part of the vm.DNSProvider interface. 
-func (n *dnsProvider) DeleteRecordsBySubdomain(_ context.Context, subdomain string) error { +func (n *dnsProvider) DeleteRecordsBySubdomain( + _ context.Context, _ *logger.Logger, subdomain string, +) error { unlock, err := lock.AcquireFilesystemLock(n.lockFilePath) if err != nil { return err diff --git a/pkg/roachprod/vm/local/dns_test.go b/pkg/roachprod/vm/local/dns_test.go index 2ada05607b9b..d196388c8072 100644 --- a/pkg/roachprod/vm/local/dns_test.go +++ b/pkg/roachprod/vm/local/dns_test.go @@ -31,7 +31,7 @@ func createTestDNSRecords(testRecords ...dnsTestRec) []vm.DNSRecord { func createTestDNSProvider(t *testing.T, testRecords ...dnsTestRec) vm.DNSProvider { p := NewDNSProvider(t.TempDir(), "local-zone") - err := p.CreateRecords(context.Background(), createTestDNSRecords(testRecords...)...) + err := p.CreateRecords(context.Background(), nil, createTestDNSRecords(testRecords...)...) require.NoError(t, err) return p } @@ -48,7 +48,7 @@ func TestLookupRecords(t *testing.T) { }...) t.Run("lookup system", func(t *testing.T) { - records, err := p.LookupSRVRecords(ctx, "_system-sql._tcp.local.local-zone") + records, err := p.LookupSRVRecords(ctx, nil, "_system-sql._tcp.local.local-zone") require.NoError(t, err) require.Equal(t, 3, len(records)) for _, r := range records { @@ -58,7 +58,7 @@ func TestLookupRecords(t *testing.T) { }) t.Run("parse SRV data", func(t *testing.T) { - records, err := p.LookupSRVRecords(ctx, "_tenant-1-sql._tcp.local.local-zone") + records, err := p.LookupSRVRecords(ctx, nil, "_tenant-1-sql._tcp.local.local-zone") require.NoError(t, err) require.Equal(t, 1, len(records)) data, err := records[0].ParseSRVRecord() diff --git a/pkg/roachprod/vm/local/local.go b/pkg/roachprod/vm/local/local.go index 290803ceaba3..af026f477a0b 100644 --- a/pkg/roachprod/vm/local/local.go +++ b/pkg/roachprod/vm/local/local.go @@ -79,8 +79,13 @@ func DeleteCluster(l *logger.Logger, name string) error { for i := range c.VMs { path := VMDir(c.Name, i+1) - if err := 
os.RemoveAll(path); err != nil { - return err + if config.DryRun || config.Verbose { + l.Printf("exec: rm -rf %s", path) + } + if !config.DryRun { + if err := os.RemoveAll(path); err != nil { + return err + } } } @@ -93,7 +98,7 @@ func DeleteCluster(l *logger.Logger, name string) error { // Local clusters are expected to specifically use the local DNS provider // implementation, and should clean up any DNS records in the local file // system cache. - return p.DeleteRecordsBySubdomain(context.Background(), c.Name) + return p.DeleteRecordsBySubdomain(context.Background(), l, c.Name) } // Clusters returns a list of all known local clusters. @@ -219,7 +224,9 @@ func (p *Provider) ListLoadBalancers(*logger.Logger, vm.List) ([]vm.ServiceAddre return nil, nil } -func (p *Provider) createVM(clusterName string, index int, creationTime time.Time) (vm.VM, error) { +func (p *Provider) createVM( + l *logger.Logger, clusterName string, index int, creationTime time.Time, +) (vm.VM, error) { cVM := vm.VM{ Name: "localhost", CreatedAt: creationTime, @@ -237,9 +244,14 @@ func (p *Provider) createVM(clusterName string, index int, creationTime time.Tim LocalClusterName: clusterName, } path := VMDir(clusterName, index+1) - err := os.MkdirAll(path, 0755) - if err != nil { - return vm.VM{}, err + if config.DryRun || config.Verbose { + l.Printf("exec: mkdir -p %s", path) + } + if !config.DryRun { + err := os.MkdirAll(path, 0755) + if err != nil { + return vm.VM{}, err + } } return cVM, nil } @@ -262,7 +274,7 @@ func (p *Provider) Create( for i := range names { var err error - c.VMs[i], err = p.createVM(c.Name, i, now) + c.VMs[i], err = p.createVM(l, c.Name, i, now) if err != nil { return err } @@ -278,7 +290,7 @@ func (p *Provider) Grow(l *logger.Logger, vms vm.List, clusterName string, names now := timeutil.Now() offset := p.clusters[clusterName].VMs.Len() for i := range names { - cVM, err := p.createVM(clusterName, i+offset, now) + cVM, err := p.createVM(l, clusterName, i+offset, now) 
if err != nil { return err } @@ -294,8 +306,13 @@ func (p *Provider) Shrink(l *logger.Logger, vmsToDelete vm.List, clusterName str continue } path := VMDir(clusterName, i+1) - if err := os.RemoveAll(path); err != nil { - return err + if config.DryRun || config.Verbose { + l.Printf("exec: rm -rf %s", path) + } + if !config.DryRun { + if err := os.RemoveAll(path); err != nil { + return err + } } } p.clusters[clusterName].VMs = p.clusters[clusterName].VMs[:keepCount]