Skip to content

Commit

Permalink
Merge pull request #6 from NVIDIA/sshUser
Browse files Browse the repository at this point in the history
Add UseName to Auth parameter
  • Loading branch information
ArangoGutierrez authored Feb 2, 2024
2 parents 45c9212 + ce9f741 commit 85fc4dd
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 17 deletions.
2 changes: 2 additions & 0 deletions api/holodeck/v1alpha1/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ type Properties struct {

type Auth struct {
KeyName string `json:"keyName"`
// Username for the SSH connection
Username string `json:"username"`
// Path to the public key file on the local machine
PublicKey string `json:"publicKey"`
// Path to the private key file on the local machine
Expand Down
16 changes: 15 additions & 1 deletion cmd/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ func (m command) build() *cli.Command {
opts.cfg.Spec.ContainerRuntime.Name = v1alpha1.ContainerRuntimeNone
}

// If no username is specified, default to ubuntu
if opts.cfg.Spec.Auth.Username == "" {
// TODO (ArangoGutierrez): This should be based on the OS
// Amazon Linux: ec2-user
// Ubuntu: ubuntu
// CentOS: centos
// Debian: admin
// RHEL: ec2-user
// Fedora: ec2-user
// SUSE: ec2-user

opts.cfg.Spec.Auth.Username = "ubuntu"
}

return nil
},
Action: func(c *cli.Context) error {
Expand Down Expand Up @@ -159,7 +173,7 @@ func runProvision(log *logger.FunLogger, opts *options) error {
hostUrl = opts.cfg.Spec.Instance.HostUrl
}

p, err := provisioner.New(log, opts.cfg.Spec.Auth.PrivateKey, hostUrl)
p, err := provisioner.New(log, opts.cfg.Spec.Auth.PrivateKey, opts.cfg.Spec.Auth.Username, hostUrl)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/dryrun/dryrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (m command) run(c *cli.Context, opts *options) error {
return err
}
case v1alpha1.ProviderSSH:
if err := connectOrDie(opts.cfg.Spec.Auth.PrivateKey, opts.cfg.Spec.Instance.HostUrl); err != nil {
if err := connectOrDie(opts.cfg.Spec.Auth.PrivateKey, opts.cfg.Spec.Username, opts.cfg.Spec.Instance.HostUrl); err != nil {
return err
}
default:
Expand Down Expand Up @@ -124,7 +124,7 @@ func validateAWS(log *logger.FunLogger, opts *options) error {
}

// createSshClient creates a ssh client, and retries if it fails to connect
func connectOrDie(keyPath, hostUrl string) error {
func connectOrDie(keyPath, userName, hostUrl string) error {
var err error
key, err := os.ReadFile(keyPath)
if err != nil {
Expand All @@ -135,7 +135,7 @@ func connectOrDie(keyPath, hostUrl string) error {
return fmt.Errorf("failed to parse private key: %v", err)
}
sshConfig := &ssh.ClientConfig{
User: "ubuntu",
User: userName,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
Expand Down
28 changes: 15 additions & 13 deletions pkg/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,27 @@ type Provisioner struct {
Client *ssh.Client
SessionManager *ssm.Client

HostUrl string
KeyPath string
tpl bytes.Buffer
HostUrl string
UserName string
KeyPath string
tpl bytes.Buffer

log *logger.FunLogger
}

func New(log *logger.FunLogger, keyPath, hostUrl string) (*Provisioner, error) {
client, err := connectOrDie(keyPath, hostUrl)
func New(log *logger.FunLogger, keyPath, userName, hostUrl string) (*Provisioner, error) {
client, err := connectOrDie(keyPath, userName, hostUrl)
if err != nil {
return nil, fmt.Errorf("failed to connect to %s: %v", hostUrl, err)
}

p := &Provisioner{
Client: client,
HostUrl: hostUrl,
KeyPath: keyPath,
tpl: bytes.Buffer{},
log: log,
Client: client,
HostUrl: hostUrl,
UserName: userName,
KeyPath: keyPath,
tpl: bytes.Buffer{},
log: log,
}

return p, nil
Expand Down Expand Up @@ -121,7 +123,7 @@ func (p *Provisioner) provision() error {
var err error

// Create a new ssh connection
p.Client, err = connectOrDie(p.KeyPath, p.HostUrl)
p.Client, err = connectOrDie(p.KeyPath, p.UserName, p.HostUrl)
if err != nil {
return fmt.Errorf("failed to connect to %s: %v", p.HostUrl, err)
}
Expand Down Expand Up @@ -221,7 +223,7 @@ func addScriptHeader(tpl *bytes.Buffer) error {
}

// createSshClient creates a ssh client, and retries if it fails to connect
func connectOrDie(keyPath, hostUrl string) (*ssh.Client, error) {
func connectOrDie(keyPath, userName, hostUrl string) (*ssh.Client, error) {
var client *ssh.Client
var err error
key, err := os.ReadFile(keyPath)
Expand All @@ -233,7 +235,7 @@ func connectOrDie(keyPath, hostUrl string) (*ssh.Client, error) {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
sshConfig := &ssh.ClientConfig{
User: "ubuntu",
User: userName,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
Expand Down

0 comments on commit 85fc4dd

Please sign in to comment.