Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the possibility to ignore the Match directive #65

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 40 additions & 34 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// the host name to match on ("example.com"), and the second argument is the key
// you want to retrieve ("Port"). The keywords are case insensitive.
//
// port := ssh_config.Get("myhost", "Port")
// port := ssh_config.Get("myhost", "Port")
//
// You can also manipulate an SSH config file and then print it or write it back
// to disk.
Expand Down Expand Up @@ -52,15 +52,16 @@ type configFinder func() string
// UserSettings checks ~/.ssh and /etc/ssh for configuration files. The config
// files are parsed and cached the first time Get() or GetStrict() is called.
type UserSettings struct {
IgnoreErrors bool
customConfig *Config
customConfigFinder configFinder
systemConfig *Config
systemConfigFinder configFinder
userConfig *Config
userConfigFinder configFinder
loadConfigs sync.Once
onceErr error
IgnoreErrors bool
IgnoreMatchDirective bool
customConfig *Config
customConfigFinder configFinder
systemConfig *Config
systemConfigFinder configFinder
userConfig *Config
userConfigFinder configFinder
loadConfigs sync.Once
onceErr error
}

func homedir() string {
Expand All @@ -80,9 +81,10 @@ func userConfigFinder() string {
// GetStrict. It checks both $HOME/.ssh/config and /etc/ssh/ssh_config for keys,
// and it will return parse errors (if any) instead of swallowing them.
var DefaultUserSettings = &UserSettings{
IgnoreErrors: false,
systemConfigFinder: systemConfigFinder,
userConfigFinder: userConfigFinder,
IgnoreErrors: false,
IgnoreMatchDirective: false,
systemConfigFinder: systemConfigFinder,
userConfigFinder: userConfigFinder,
}

func systemConfigFinder() string {
Expand Down Expand Up @@ -277,10 +279,11 @@ func (u *UserSettings) doLoadConfigs() {
var err error
if u.customConfigFinder != nil {
filename = u.customConfigFinder()
u.customConfig, err = parseFile(filename)
u.customConfig, err = parseFile(filename, u.IgnoreMatchDirective)
// IsNotExist should be returned because a user specified this
// function - not existing likely means they made an error
if err != nil {
// We should also respect the ignore flag
if err != nil && !u.IgnoreErrors {
u.onceErr = err
}
return
Expand All @@ -290,7 +293,7 @@ func (u *UserSettings) doLoadConfigs() {
} else {
filename = u.userConfigFinder()
}
u.userConfig, err = parseFile(filename)
u.userConfig, err = parseFile(filename, u.IgnoreMatchDirective)
//lint:ignore S1002 I prefer it this way
if err != nil && os.IsNotExist(err) == false {
u.onceErr = err
Expand All @@ -301,25 +304,26 @@ func (u *UserSettings) doLoadConfigs() {
} else {
filename = u.systemConfigFinder()
}
u.systemConfig, err = parseFile(filename)
u.systemConfig, err = parseFile(filename, u.IgnoreMatchDirective)
//lint:ignore S1002 I prefer it this way
if err != nil && os.IsNotExist(err) == false {
u.onceErr = err
return
}
})
},
)
}

func parseFile(filename string) (*Config, error) {
return parseWithDepth(filename, 0)
func parseFile(filename string, ignoreMatchDirective bool) (*Config, error) {
return parseWithDepth(filename, ignoreMatchDirective, 0)
}

func parseWithDepth(filename string, depth uint8) (*Config, error) {
func parseWithDepth(filename string, ignoreMatchDirective bool, depth uint8) (*Config, error) {
b, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return decodeBytes(b, isSystem(filename), depth)
return decodeBytes(b, isSystem(filename), ignoreMatchDirective, depth)
}

func isSystem(filename string) bool {
Expand All @@ -329,21 +333,21 @@ func isSystem(filename string) bool {

// Decode reads r into a Config, or returns an error if r could not be parsed as
// an SSH config file.
func Decode(r io.Reader) (*Config, error) {
func Decode(r io.Reader, ignoreMatchDirective bool) (*Config, error) {
b, err := io.ReadAll(r)
if err != nil {
return nil, err
}
return decodeBytes(b, false, 0)
return decodeBytes(b, false, ignoreMatchDirective, 0)
iamFrancescoFerro marked this conversation as resolved.
Show resolved Hide resolved
}

// DecodeBytes reads b into a Config, or returns an error if r could not be
// parsed as an SSH config file.
func DecodeBytes(b []byte) (*Config, error) {
return decodeBytes(b, false, 0)
func DecodeBytes(b []byte, ignoreMatchDirective bool) (*Config, error) {
return decodeBytes(b, false, ignoreMatchDirective, 0)
}

func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) {
func decodeBytes(b []byte, system, ignoreMatchDirective bool, depth uint8) (c *Config, err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
Expand All @@ -357,17 +361,18 @@ func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) {
}
}()

c = parseSSH(lexSSH(b), system, depth)
c = parseSSH(lexSSH(b), system, ignoreMatchDirective, depth)
return c, err
}

// Config represents an SSH config file.
type Config struct {
// A list of hosts to match against. The file begins with an implicit
// "Host *" declaration matching all hosts.
Hosts []*Host
depth uint8
position Position
Hosts []*Host
depth uint8
position Position
ignoreMatchDirective bool
}

// Get finds the first value in the configuration that matches the alias and
Expand All @@ -388,7 +393,7 @@ func (c *Config) Get(alias, key string) (string, error) {
case *KV:
// "keys are case insensitive" per the spec
lkey := strings.ToLower(t.Key)
if lkey == "match" {
if lkey == "match" && !c.ignoreMatchDirective {
panic("can't handle Match directives")
}
if lkey == lowerKey {
Expand Down Expand Up @@ -711,7 +716,8 @@ func removeDups(arr []string) []string {
// Configuration files are parsed greedily (e.g. as soon as this function runs).
// Any error encountered while parsing nested configuration files will be
// returned.
func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system bool, depth uint8) (*Include, error) {
func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system, ignoreMatchDirective bool, depth uint8,
) (*Include, error) {
if depth > maxRecurseDepth {
return nil, ErrDepthExceeded
}
Expand Down Expand Up @@ -744,7 +750,7 @@ func NewInclude(directives []string, hasEquals bool, pos Position, comment strin
matches = removeDups(matches)
inc.matches = matches
for i := range matches {
config, err := parseWithDepth(matches[i], depth)
config, err := parseWithDepth(matches[i], ignoreMatchDirective, depth)
if err != nil {
return nil, err
}
Expand Down
30 changes: 28 additions & 2 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var files = []string{
func TestDecode(t *testing.T) {
for _, filename := range files {
data := loadFile(t, filename)
cfg, err := Decode(bytes.NewReader(data))
cfg, err := Decode(bytes.NewReader(data), false)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -339,7 +339,7 @@ func TestIncludeString(t *testing.T) {
if err != nil {
log.Fatal(err)
}
c, err := Decode(bytes.NewReader(data))
c, err := Decode(bytes.NewReader(data), false)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -467,3 +467,29 @@ func TestCustomFinder(t *testing.T) {
t.Errorf("expected to find User root, got %q", val)
}
}

func TestCustomFinderWhenIgnoringMatchDirective(t *testing.T) {
us := &UserSettings{
IgnoreMatchDirective: true,
}
us.ConfigFinder(func() string {
return "testdata/config1-with-match-directive"
})

val := us.Get("git.yahoo.com", "HostName")
if val != "git.proxy.com" {
t.Errorf("expected to find Hostname git.proxy.com, got %q", val)
}
}

func TestCustomFinderWhenNotIgnoringMatchDirective(t *testing.T) {
us := &UserSettings{}
us.ConfigFinder(func() string {
return "testdata/config1-with-match-directive"
})

val := us.Get("git.yahoo.com", "HostName")
if val != "" {
t.Errorf("expected to find Hostname empty %q", val)
}
}
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Host *.example.com
Compression yes
`

cfg, _ := ssh_config.Decode(strings.NewReader(config))
cfg, _ := ssh_config.Decode(strings.NewReader(config), false)
val, _ := cfg.Get("test.example.com", "Compression")
fmt.Println(val)
// Output: yes
Expand Down
36 changes: 20 additions & 16 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import (
)

type sshParser struct {
flow chan token
config *Config
tokensBuffer []token
currentTable []string
seenTableKeys []string
ignoreMatchDirective bool
flow chan token
config *Config
tokensBuffer []token
currentTable []string
seenTableKeys []string
// /etc/ssh parser or local parser - used to find the default for relative
// filepaths in the Include directive
system bool
Expand Down Expand Up @@ -104,7 +105,7 @@ func (p *sshParser) parseKV() sshParserStateFn {
tok = p.getToken()
comment = tok.val
}
if strings.ToLower(key.val) == "match" {
if strings.ToLower(key.val) == "match" && !p.ignoreMatchDirective {
// https://github.com/kevinburke/ssh_config/issues/6
p.raiseErrorf(val, "ssh_config: Match directive parsing is unsupported")
return nil
Expand All @@ -127,18 +128,20 @@ func (p *sshParser) parseKV() sshParserStateFn {
hostval := strings.TrimRightFunc(val.val, unicode.IsSpace)
spaceBeforeComment := val.val[len(hostval):]
val.val = hostval
p.config.ignoreMatchDirective = p.ignoreMatchDirective
p.config.Hosts = append(p.config.Hosts, &Host{
Patterns: patterns,
Nodes: make([]Node, 0),
EOLComment: comment,
spaceBeforeComment: spaceBeforeComment,
hasEquals: hasEquals,
})
},
)
return p.parseStart
}
lastHost := p.config.Hosts[len(p.config.Hosts)-1]
if strings.ToLower(key.val) == "include" {
inc, err := NewInclude(strings.Split(val.val, " "), hasEquals, key.Position, comment, p.system, p.depth+1)
inc, err := NewInclude(strings.Split(val.val, " "), hasEquals, key.Position, comment, p.system, p.ignoreMatchDirective, p.depth+1)
if err == ErrDepthExceeded {
p.raiseError(val, err)
return nil
Expand Down Expand Up @@ -177,7 +180,7 @@ func (p *sshParser) parseComment() sshParserStateFn {
return p.parseStart
}

func parseSSH(flow chan token, system bool, depth uint8) *Config {
func parseSSH(flow chan token, system, ignoreMatchDirective bool, depth uint8) *Config {
// Ensure we consume tokens to completion even if parser exits early
defer func() {
for range flow {
Expand All @@ -187,13 +190,14 @@ func parseSSH(flow chan token, system bool, depth uint8) *Config {
result := newConfig()
result.position = Position{1, 1}
parser := &sshParser{
flow: flow,
config: result,
tokensBuffer: make([]token, 0),
currentTable: make([]string, 0),
seenTableKeys: make([]string, 0),
system: system,
depth: depth,
ignoreMatchDirective: ignoreMatchDirective,
flow: flow,
config: result,
tokensBuffer: make([]token, 0),
currentTable: make([]string, 0),
seenTableKeys: make([]string, 0),
system: system,
depth: depth,
}
parser.run()
return result
Expand Down
2 changes: 1 addition & 1 deletion parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func (b *errReader) Read(p []byte) (n int, err error) {

func TestIOError(t *testing.T) {
buf := &errReader{}
_, err := Decode(buf)
_, err := Decode(buf, false)
if err == nil {
t.Fatal("expected non-nil err, got nil")
}
Expand Down
6 changes: 6 additions & 0 deletions testdata/config1-with-match-directive
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Match all
Include ~/.ssh
Host *
User usr
Host git.yahoo.com
HostName git.proxy.com