diff --git a/helper.go b/helper.go index b59a50b..e704554 100644 --- a/helper.go +++ b/helper.go @@ -219,6 +219,9 @@ func formatValueInline(v interface{}) string { string: return formatValue(v) + case *CallbackMatch: + return formatValue(m.matcher()) + case Matcher: return fmt.Sprintf("%T(%q)", v, m.Expected()) @@ -232,12 +235,15 @@ func formatType(v interface{}) string { return "" } - switch v.(type) { + switch m := v.(type) { case *ExactMatch, []byte, string: return "" + case *CallbackMatch: + return formatType(m.matcher()) + default: return fmt.Sprintf(" using %T", v) } @@ -249,6 +255,9 @@ func formatValue(v interface{}) string { } switch m := v.(type) { + case *CallbackMatch: + return formatValue(m.matcher()) + case Matcher: return m.Expected() diff --git a/helper_test.go b/helper_test.go index 0fce67f..de4fe56 100644 --- a/helper_test.go +++ b/helper_test.go @@ -106,6 +106,13 @@ func TestFormatValueInline(t *testing.T) { value: "expected", expected: "expected", }, + { + scenario: "Callback", + value: Match(func() Matcher { + return Exact("expected") + }), + expected: "expected", + }, { scenario: "Matcher", value: JSON("{}"), @@ -158,6 +165,13 @@ func TestFormatType(t *testing.T) { value: "expected", expected: "", }, + { + scenario: "Callback", + value: Match(func() Matcher { + return Exact("expected") + }), + expected: "", + }, { scenario: "Matcher", value: JSON("{}"), @@ -197,6 +211,13 @@ func TestFormatValue(t *testing.T) { value: "expected", expected: "expected", }, + { + scenario: "Callback", + value: Match(func() Matcher { + return Exact("expected") + }), + expected: "expected", + }, { scenario: "ExactMatch", value: Exact("expected"), diff --git a/matcher.go b/matcher.go index 26b6f4d..f4236ec 100644 --- a/matcher.go +++ b/matcher.go @@ -62,25 +62,32 @@ func (m *RegexMatch) Match(actual string) bool { // CallbackMatch matches by calling a function. type CallbackMatch struct { - expect func() string - match func(actual string) bool + callback func() Matcher + upstream Matcher +} + +func (m *CallbackMatch) matcher() Matcher { + if m.upstream == nil { + m.upstream = m.callback() + } + + return m.upstream } // Expected returns the expectation. func (m *CallbackMatch) Expected() string { - return m.expect() + return m.matcher().Expected() } // Match determines if the actual is expected. func (m *CallbackMatch) Match(actual string) bool { - return m.match(actual) + return m.matcher().Match(actual) } // Match creates a callback matcher. -func Match(expect func() string, match func(actual string) bool) Matcher { +func Match(callback func() Matcher) Matcher { return &CallbackMatch{ - expect: expect, - match: match, + callback: callback, } } @@ -110,6 +117,9 @@ func ValueMatcher(v interface{}) Matcher { case Matcher: return val + case func() Matcher: + return Match(val) + case []byte: return Exact(string(val)) diff --git a/matcher_test.go b/matcher_test.go index 3b407a7..fd77376 100644 --- a/matcher_test.go +++ b/matcher_test.go @@ -170,22 +170,6 @@ func TestRegexMatch_Match(t *testing.T) { } } -func TestCallbackMatch(t *testing.T) { - t.Parallel() - - m := Match( - func() string { - return "expected" - }, - func(string) bool { - return false - }, - ) - - assert.Equal(t, "expected", m.Expected()) - assert.False(t, m.Match("actual")) -} - func TestValueMatcher(t *testing.T) { t.Parallel() @@ -231,6 +215,18 @@ func TestValueMatcher(t *testing.T) { } } +func TestValueMatcher_Match(t *testing.T) { + t.Parallel() + + m := ValueMatcher(func() Matcher { + return Exact("expected") + }) + + assert.Equal(t, "expected", m.Expected()) + assert.True(t, m.Match("expected")) + assert.False(t, m.Match("Mismatch")) +} + func TestValueMatcher_Panic(t *testing.T) { t.Parallel()