Skip to content

Commit

Permalink
Merge pull request #453 from gravitational/sasha/vars
Browse files Browse the repository at this point in the history
add support for passing env variables, fixes #451
  • Loading branch information
kontsevoy authored Jun 10, 2016
2 parents a637a7d + b3a105e commit 741a70a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
20 changes: 20 additions & 0 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,26 @@ func (s *IntSuite) TestInteractive(c *check.C) {
c.Assert(string(personA.Output(100)), check.DeepEquals, string(personB.Output(100)))
}

// TestInvalidLogins validates that you can't login with invalid login or
// with invalid 'site' parameter
func (s *IntSuite) TestEnvironmentVariables(c *check.C) {
t := s.newTeleport(c, nil, true)
defer t.Stop(true)

testKey, testVal := "TELEPORT_TEST_ENV", "howdy"
cmd := []string{"echo", fmt.Sprintf("$%v", testKey)}

// make sure sessions set run command
tc, err := t.NewClient(s.me.Username, Site, Host, t.GetPortSSHInt())
tc.Env = map[string]string{testKey: testVal}
out := &bytes.Buffer{}
tc.Stdout = out
c.Assert(err, check.IsNil)
err = tc.SSH(cmd, false, nil)
c.Assert(err, check.IsNil)
c.Assert(strings.TrimSpace(out.String()), check.Equals, testVal)
}

// TestInvalidLogins validates that you can't login with invalid login or
// with invalid 'site' parameter
func (s *IntSuite) TestInvalidLogins(c *check.C) {
Expand Down
7 changes: 5 additions & 2 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ type Config struct {
// KeyDir defines where temporary session keys will be stored.
// if empty, they'll go to ~/.tsh
KeysDir string

// Env is a map of environmnent variables to send when opening session
Env map[string]string
}

// ProxyHostPort returns a full host:port address of the proxy or an empty string if no
Expand Down Expand Up @@ -618,7 +621,7 @@ func (tc *TeleportClient) runCommand(siteName string, nodeAddresses []string, pr
if len(nodeAddresses) > 1 {
fmt.Printf("Running command on %v:\n", address)
}
err = nodeClient.Run(command, stdin, tc.Stdout, tc.Stderr)
err = nodeClient.Run(command, stdin, tc.Stdout, tc.Stderr, tc.Config.Env)
if err != nil {
exitErr, ok := err.(*ssh.ExitError)
if ok {
Expand Down Expand Up @@ -690,7 +693,7 @@ func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessionID session.ID,
winSize = &term.Winsize{Width: 80, Height: 25}
}

shell, err := nodeClient.Shell(int(winSize.Width), int(winSize.Height), sessionID)
shell, err := nodeClient.Shell(int(winSize.Width), int(winSize.Height), sessionID, tc.Config.Env)
if err != nil {
return trace.Wrap(err)
}
Expand Down
19 changes: 17 additions & 2 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (proxy *ProxyClient) Close() error {
}

// Shell returns remote shell as io.ReadWriterCloser object
func (client *NodeClient) Shell(width, height int, sessionID session.ID) (io.ReadWriteCloser, error) {
func (client *NodeClient) Shell(width, height int, sessionID session.ID, env map[string]string) (io.ReadWriteCloser, error) {
if sessionID == "" {
// initiate a new session if not passed
sessionID = session.NewID()
Expand Down Expand Up @@ -294,6 +294,14 @@ func (client *NodeClient) Shell(width, height int, sessionID session.ID) (io.Rea
}
}

// pass environment variables set by client
for key, val := range env {
err = clientSession.Setenv(key, val)
if err != nil {
log.Warn(err)
}
}

terminalModes := ssh.TerminalModes{}

err = clientSession.RequestPty("xterm", height, width, terminalModes)
Expand Down Expand Up @@ -414,11 +422,18 @@ func (client *NodeClient) Shell(width, height int, sessionID session.ID) (io.Rea

// Run executes command on the remote server and writes its stdout to
// the 'output' argument
func (client *NodeClient) Run(cmd []string, stdin io.Reader, stdout, stderr io.Writer) error {
func (client *NodeClient) Run(cmd []string, stdin io.Reader, stdout, stderr io.Writer, env map[string]string) error {
session, err := client.Client.NewSession()
if err != nil {
return trace.Wrap(err)
}
// pass environment variables set by client
for key, val := range env {
err = session.Setenv(key, val)
if err != nil {
log.Warn(err)
}
}
session.Stdout = stdout
session.Stderr = stderr
session.Stdin = stdin
Expand Down

0 comments on commit 741a70a

Please sign in to comment.