-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathassert.go
345 lines (297 loc) · 7.32 KB
/
assert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
package assert
import (
"bytes"
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"github.com/r3labs/diff/v3"
"github.com/sanity-io/litter"
"github.com/sergi/go-diff/diffmatchpatch"
)
// DiffOptions controls how litter dumps values when a failed equality
// assertion renders the difference between got and want. It is an
// exported package variable, so callers may adjust it to customize the
// diff output.
var DiffOptions = litter.Options{
	Separator:         " ",
	HidePrivateFields: false,
	StripPackageNames: false,
}
// testingTB is the small part of testing.TB that the assertions need.
// Depending only on these two methods lets this package's own tests pass
// in a substitute implementation instead of a real *testing.T.
type testingTB interface {
	Fatalf(format string, args ...any)
	Helper()
}
// Equal fails the test when got and want are not equal, printing a diff
// of the two values.
//
// Equality is decided by the first rule that applies:
//
//  1. both values nil: equal;
//  2. exactly one value nil: not equal;
//  3. got has an Equal(V) bool method: its result is used;
//  4. the values are []byte: bytes.Equal is used;
//  5. otherwise: reflect.DeepEqual is used.
//
// Passing an error is a programming mistake and panics; use assert.Error
// for errors.
func Equal[V any](t testingTB, got V, want V) {
	if _, isErr := any(got).(error); isErr {
		panic("use assert.Error() for errors")
	}
	t.Helper()
	if areEqual(got, want) {
		return
	}
	t.Fatalf("expected equal\n%s", diffValue(got, want))
}
// NotEqual fails the test when got and want are equal, using the same
// equality rules as [Equal]. Passing an error is a programming mistake
// and panics; use assert.Error for errors.
func NotEqual[T any](t testingTB, got T, want T) {
	if _, isErr := any(got).(error); isErr {
		panic("use assert.Error() for errors")
	}
	t.Helper()
	if !areEqual(got, want) {
		return
	}
	t.Fatalf("expected not equal, but got equal")
}
// Error fails the test when err is nil.
func Error(t testingTB, err error) {
	t.Helper()
	if err != nil {
		return
	}
	t.Fatalf("expected error, got nil")
}
// NoError fails the test when err is non-nil, including the error text
// in the failure message.
func NoError(t testingTB, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("unexpected error: %v", err)
}
// ErrorContains fails the test unless err is non-nil and matches target.
//
// Target can be:
//
//  1. string
//
//     The string is compiled as a regexp and the error message is matched
//     against it. If it is not a valid regexp, the error message must
//     contain it as a plain substring.
//
//  2. error
//
//     The error chain must contain the target, checked with errors.Is.
//
//  3. anything else
//
//     The target is passed to errors.As, so it must be a non-nil pointer to
//     a type implementing error or to an interface type. A panic raised by
//     errors.Is/As (e.g. for an invalid target) is reported as a failure.
func ErrorContains(t testingTB, err error, target any) {
	t.Helper()
	if err == nil {
		t.Fatalf("error is nil")
		return
	}
	// errors.As panics on an invalid target; report that as a test failure.
	// The closure is the frame calling Fatalf, so it must mark itself as a
	// helper for the failure to be attributed to the caller.
	defer func() {
		t.Helper()
		if r := recover(); r != nil {
			t.Fatalf("errors.Is/As panic: %v", r)
		}
	}()
	switch e := target.(type) {
	case string:
		msg := err.Error()
		// A valid regexp is used as a pattern; anything else is a literal.
		if re, reErr := regexp.Compile(e); reErr == nil {
			if !re.MatchString(msg) {
				t.Fatalf("unexpected error: %q does not match %q", err, e)
			}
			return
		}
		if !strings.Contains(msg, e) {
			t.Fatalf("unexpected error: %q does not contain %q", err, e)
		}
	case error:
		if !errors.Is(err, e) {
			// Print the target's message as well as its type: sentinels made
			// with errors.New all share one type, so %T alone is ambiguous.
			t.Fatalf("unexpected error: %q is not %q (%T)", err, e, e)
		}
	default:
		if !errors.As(err, e) {
			t.Fatalf("unexpected error: %q is not %T", err, e)
		}
	}
}
// Zero fails the test unless got equals the zero value of its type.
func Zero[T comparable](t testingTB, got T) {
	t.Helper()
	var zero T
	if got == zero {
		return
	}
	t.Fatalf("expected zero, got %v", got)
}
// NotZero fails the test when got equals the zero value of its type.
func NotZero[V comparable](t testingTB, got V) {
	t.Helper()
	var zero V
	if got != zero {
		return
	}
	t.Fatalf("expected not zero, got %v", got)
}
// Nil fails the test unless got is nil. Besides an untyped nil, it also
// accepts nil values of nilable kinds (pointer, slice, map, channel,
// function). Passing an error is a programming mistake and panics; use
// assert.NoError for errors.
func Nil(t testingTB, got any) {
	if _, isErr := got.(error); isErr {
		panic("use assert.NoError() for errors")
	}
	t.Helper()
	if isNil(got) {
		return
	}
	t.Fatalf("expected nil, got %v", got)
}
// NotNil fails the test when got is nil, whether it is an untyped nil or
// a nil value of a nilable kind. Passing an error is a programming
// mistake and panics; use assert.Error for errors.
func NotNil(t testingTB, got any) {
	if _, isErr := got.(error); isErr {
		panic("use assert.Error() for errors")
	}
	t.Helper()
	if !isNil(got) {
		return
	}
	t.Fatalf("expected not nil, got nil")
}
// Len fails the test unless the length of got is want.
//
// got can be any value accepted by the builtin len function: an array,
// pointer to array, slice, string, map or channel. Any other value
// (including a nil interface) is reported as a test failure rather than
// letting reflect panic.
func Len[V any](t testingTB, got V, want int) {
	t.Helper()
	v := reflect.ValueOf(got)
	var l int
	switch v.Kind() {
	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
		l = v.Len()
	case reflect.Pointer:
		if v.Type().Elem().Kind() != reflect.Array {
			t.Fatalf("cannot take length of %T", got)
			return
		}
		// As with the builtin len, a pointer to an array (even a nil one)
		// has the array's static length.
		l = v.Type().Elem().Len()
	default:
		t.Fatalf("cannot take length of %T", got)
		return
	}
	if l != want {
		t.Fatalf("expected length %d, got %d", want, l)
	}
}
// True fails the test when got is false.
func True(t testingTB, got bool) {
	t.Helper()
	if got {
		return
	}
	t.Fatalf("expected true, got false")
}
// False fails the test when got is true.
func False(t testingTB, got bool) {
	t.Helper()
	if !got {
		return
	}
	t.Fatalf("expected false, got true")
}
// Panic runs f and fails the test if f returns without panicking. The
// panic is recovered, so it does not propagate to the caller.
func Panic(t testingTB, f func()) {
	t.Helper()
	defer func() {
		t.Helper()
		if recovered := recover(); recovered != nil {
			return
		}
		t.Fatalf("expected panic, got nothing")
	}()
	f()
}
// NotPanic runs f and fails the test if f panics, reporting the recovered
// panic value. The panic is recovered, so it does not propagate to the
// caller.
func NotPanic(t testingTB, f func()) {
	t.Helper()
	defer func() {
		t.Helper()
		recovered := recover()
		if recovered == nil {
			return
		}
		t.Fatalf("unexpected panic: %v", recovered)
	}()
	f()
}
// Defer returns a function that calls fn and fails the test if fn returns
// an error. It is intended to be deferred, e.g.
//
//	defer assert.Defer(t, f.Close)()
func Defer(t testingTB, fn func() error) func() {
	t.Helper()
	return func() {
		// The closure, not Defer, is the frame that calls Fatalf, so it must
		// mark itself as a helper for the failure to be reported at the
		// caller rather than inside this file.
		t.Helper()
		if err := fn(); err != nil {
			t.Fatalf("unexpected defer error: %v", err)
		}
	}
}
// TypeAssert fails the test unless got holds a value of type V. It returns
// the asserted value, or the zero V when the assertion fails.
//
// The failure message reads "assertion <dynamic type of got>.(<V>) failed",
// mirroring the Go expression got.(V).
func TypeAssert[V any](t testingTB, got any) V {
	t.Helper()
	v, ok := got.(V)
	if !ok {
		// Name V from its static type: when V is an interface, the zero v
		// has no dynamic type and %T would print "<nil>".
		want := reflect.TypeOf((*V)(nil)).Elem()
		t.Fatalf("assertion %T.(%v) failed", got, want)
	}
	return v
}
// areEqual reports whether got and want are equal, applying in order:
//
//  1. both nil (see isNil) -> equal; exactly one nil -> not equal;
//  2. got has an Equal(V) bool method -> its result;
//  3. both hold []byte -> bytes.Equal;
//  4. otherwise reflect.DeepEqual on the pointer-dereferenced values.
func areEqual[V any](got V, want V) bool {
	gotNil, wantNil := isNil(got), isNil(want)
	if gotNil || wantNil {
		return gotNil && wantNil
	}
	if g, ok := any(got).(interface{ Equal(V) bool }); ok {
		return g.Equal(want)
	}
	// When V is an interface type, got may hold a []byte while want holds
	// some other type; use the two-value form so that case falls through to
	// DeepEqual (which reports the type mismatch) instead of panicking.
	if g, ok := any(got).([]byte); ok {
		if w, isBytes := any(want).([]byte); isBytes {
			return bytes.Equal(g, w)
		}
	}
	return reflect.DeepEqual(deref(got), deref(want))
}
// deref follows a chain of pointers and returns the first non-pointer value
// it reaches. If the chain ends in a nil pointer, that nil pointer is
// returned, because it cannot be followed further. A nil interface yields
// nil.
func deref(a any) any {
	v := reflect.ValueOf(a)
	for v.Kind() == reflect.Pointer && !v.IsNil() {
		v = v.Elem()
	}
	if !v.IsValid() {
		// reflect.ValueOf(nil) is the zero Value; Interface would panic.
		return nil
	}
	return v.Interface()
}
// isNil reports whether obj is an untyped nil or a nil value of a kind that
// can be nil (slice, map, channel, function, pointer, unsafe pointer,
// interface). Values of any other kind are never nil.
func isNil(obj any) bool {
	v := reflect.ValueOf(obj)
	if !v.IsValid() {
		// reflect.ValueOf(nil) returns the zero Value: obj was untyped nil.
		return true
	}
	switch v.Kind() {
	case reflect.Slice, reflect.Map, reflect.Chan, reflect.Func,
		reflect.Interface, reflect.Pointer, reflect.UnsafePointer:
		return v.IsNil()
	default:
		return false
	}
}
// diffValue renders a human-readable difference between a (got) and
// b (want) for a failed equality assertion. It tries, in order:
//
//  1. the values' own GoString output, when both implement fmt.GoStringer
//     (or b is nil);
//  2. a text diff of the litter dumps (configured by DiffOptions), if the
//     dumps differ;
//  3. a structural changelog from diff.Diff, if it reports any change;
//  4. the %#v formatting of both values, each clipped to 1024 bytes.
func diffValue[V any](a V, b V) string {
	// When V is an interface type, a and b may hold different dynamic
	// types, so b must be checked separately rather than asserted blindly.
	// A nil b is still passed through; diffGoStringer prints it as <nil>.
	if ga, ok := any(a).(fmt.GoStringer); ok {
		if gb, ok := any(b).(fmt.GoStringer); ok || any(b) == nil {
			return diffGoStringer(ga, gb)
		}
	}

	// Use litter to dump the values and diff the dumps. If the dumps are
	// identical, fall through to the next method.
	as := DiffOptions.Sdump(a)
	bs := DiffOptions.Sdump(b)
	dmp := diffmatchpatch.New()
	diffs := dmp.DiffMain(bs, as, true)
	for _, d := range diffs {
		if d.Type != diffmatchpatch.DiffEqual {
			return dmp.DiffPrettyText(diffs)
		}
	}

	// Litter saw no difference: ask the diff package for a changelog. An
	// empty changelog carries no information, so fall through in that case.
	if changelog, err := diff.Diff(a, b); err == nil && len(changelog) > 0 {
		var sb strings.Builder
		sb.WriteString("\n")
		for _, c := range changelog {
			// %#v rather than %q: From/To are arbitrary values, and %q
			// renders integers as quoted characters.
			fmt.Fprintf(&sb, "[%s]%T path %s: %#v -> %#v\n",
				c.Type, a, strings.Join(c.Path, "."), c.From, c.To)
		}
		return sb.String()
	}

	// Last resort: Go-syntax representation of both values, clipped so a
	// huge value does not flood the test output.
	const maxLen = 1024
	clip := func(s string) string {
		if len(s) <= maxLen {
			return s
		}
		// The byte cut may split a multi-byte rune; drop the partial tail.
		return strings.ToValidUTF8(s[:maxLen], "")
	}
	aStr := clip(fmt.Sprintf("%#v", a))
	bStr := clip(fmt.Sprintf("%#v", b))
	return fmt.Sprintf(" got: %s\nwant: %s", aStr, bStr)
}
// diffGoStringer formats got (a) and want (b) using their GoString
// methods, substituting "<nil>" for nil values so that GoString is never
// called on them.
func diffGoStringer(a, b fmt.GoStringer) string {
	render := func(s fmt.GoStringer) string {
		if isNil(s) {
			return "<nil>"
		}
		return s.GoString()
	}
	return fmt.Sprintf("got %s, want %s", render(a), render(b))
}