forked from dagger/dagger
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcachemap.go
123 lines (101 loc) · 2.49 KB
/
cachemap.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package dagql
import (
	"context"
	"errors"
	"fmt"
	"sync"

	"github.com/opencontainers/go-digest"
)
// CacheMap caches values of type T by key K. The cacheMap implementation in
// this file is safe for concurrent use and runs at most one initializer per key
// at a time, sharing its result with every caller waiting on that key.
type CacheMap[K comparable, T any] interface {
	// GetOrInitialize returns the value for key, running fn to compute it when
	// no entry exists yet. cached reports whether the result came from an
	// existing entry rather than from this call's own run of fn.
	GetOrInitialize(ctx context.Context, key K, fn func(context.Context) (T, error)) (val T, cached bool, err error)

	// Get returns the value for key, waiting for any in-flight initialization.
	// It never runs an initializer for a missing key.
	Get(ctx context.Context, key K) (val T, err error)

	// Keys returns the keys currently present, in no particular order.
	Keys() []K
}
// cacheMap is the mutex-guarded map behind NewCacheMap and NewCache. Each key
// maps to a single result slot that is shared by every caller asking for it.
type cacheMap[K comparable, T any] struct {
	// l guards calls. The methods in this file release it before running an
	// initializer or waiting on a slot, so it is only ever held briefly.
	l sync.Mutex

	// calls holds one slot per key, including slots whose initializer is
	// still running.
	calls map[K]*cache[T]
}
// cache is a single result slot. The goroutine running the initializer writes
// val and err before marking wg done; other goroutines read them only after
// wg.Wait returns. A slot created by Set never has wg incremented, so waiting
// on it returns immediately.
type cache[T any] struct {
	val T
	err error

	wg sync.WaitGroup
}
// NewCache creates a new cache map suitable for assigning on a Server or
// multiple Servers. The returned map starts empty and is keyed by
// digest.Digest with Typed values.
//
// NOTE(review): Cache and Typed are declared outside this view; presumably
// Cache is an interface satisfied by *cacheMap[digest.Digest, Typed] — confirm.
func NewCache() Cache {
	cm := newCacheMap[digest.Digest, Typed]()
	return cm
}
// NewCacheMap returns an empty CacheMap for arbitrary key and value types,
// backed by the mutex-guarded cacheMap implementation.
func NewCacheMap[K comparable, T any]() CacheMap[K, T] {
	var cm CacheMap[K, T] = newCacheMap[K, T]()
	return cm
}
// newCacheMap allocates a cacheMap with an initialized (empty) slot map, so
// callers can store into it immediately without a nil-map panic.
func newCacheMap[K comparable, T any]() *cacheMap[K, T] {
	cm := new(cacheMap[K, T])
	cm.calls = make(map[K]*cache[T])
	return cm
}
// cacheMapContextKey is the context key that marks a context as belonging to
// the initializer for key on map m. Get and GetOrInitializeOnHit look it up
// to detect a nested request for a key whose initializer is already running.
type cacheMapContextKey[K comparable, T any] struct {
	// key is the cache key whose initializer owns the context.
	key K
	// m scopes the marker to one map, so equal keys on different maps are
	// tracked independently.
	m *cacheMap[K, T]
}
// ErrCacheMapRecursiveCall is returned by Get and GetOrInitializeOnHit (and so
// GetOrInitialize) when the supplied context shows the call is nested inside
// the initializer for the same key on the same map. Waiting in that case would
// deadlock on the caller's own pending entry. Compare with errors.Is.
var ErrCacheMapRecursiveCall = errors.New("recursive call detected")
// Set stores val under key, replacing any existing slot. The new slot's
// WaitGroup is never incremented, so later readers of key return val
// immediately. Goroutines already waiting on a replaced in-flight slot keep
// waiting on that old slot and receive its result, not val.
func (m *cacheMap[K, T]) Set(key K, val T) {
	entry := &cache[T]{val: val}

	m.l.Lock()
	defer m.l.Unlock()
	m.calls[key] = entry
}
// GetOrInitialize is GetOrInitializeOnHit without a hit callback: it returns
// the value for key, running fn only if no slot exists yet. The bool result
// reports whether the value came from an existing slot.
func (m *cacheMap[K, T]) GetOrInitialize(ctx context.Context, key K, fn func(ctx context.Context) (T, error)) (T, bool, error) {
	// GetOrInitializeOnHit skips a nil onHit, so nil behaves as a no-op callback.
	return m.GetOrInitializeOnHit(ctx, key, fn, nil)
}
// GetOrInitializeOnHit returns the value cached under key, running fn to
// produce it if no slot exists yet.
//
// Concurrent callers for the same key share one run of fn. The first caller
// installs a slot and runs fn. Later callers block until it finishes, then
// receive its value and error with the bool result set to true. If onHit is
// non-nil, it is invoked with that shared result on every such hit.
//
// fn receives a context marked with (m, key). A Get or GetOrInitialize* call
// on m for the same key made with that context fails fast with
// ErrCacheMapRecursiveCall instead of deadlocking on its own pending slot.
//
// Failures are not cached. When fn returns an error, or does not return at
// all (panic / runtime.Goexit), waiters are released with the failure and the
// slot is evicted so a later call can retry. A panic still propagates to the
// caller that ran fn.
func (m *cacheMap[K, T]) GetOrInitializeOnHit(ctx context.Context, key K, fn func(ctx context.Context) (T, error), onHit func(T, error)) (T, bool, error) {
	ctxKey := cacheMapContextKey[K, T]{key: key, m: m}
	if v := ctx.Value(ctxKey); v != nil {
		var zero T
		return zero, false, ErrCacheMapRecursiveCall
	}

	m.l.Lock()
	if c, ok := m.calls[key]; ok {
		m.l.Unlock()
		c.wg.Wait()
		if onHit != nil {
			onHit(c.val, c.err)
		}
		return c.val, true, c.err
	}

	c := &cache[T]{}
	c.wg.Add(1)
	m.calls[key] = c
	m.l.Unlock()

	// Finish the slot from a deferred func so waiters are always released and
	// a failed slot is always evicted, even if fn never returns normally.
	// Without this, a panicking fn would leave every other caller for key
	// blocked in c.wg.Wait forever.
	fnReturned := false
	defer func() {
		if !fnReturned {
			c.err = fmt.Errorf("cache initializer for key %v did not complete", key)
		}
		c.wg.Done()
		if c.err != nil {
			m.l.Lock()
			// Only evict our own slot: a Set for the same key may have
			// replaced it while fn was running, and that value must survive.
			if m.calls[key] == c {
				delete(m.calls, key)
			}
			m.l.Unlock()
		}
	}()

	c.val, c.err = fn(context.WithValue(ctx, ctxKey, struct{}{}))
	fnReturned = true
	return c.val, false, c.err
}
// Get returns the value stored under key. If that slot's initializer is still
// running, Get waits for it and returns its value and error. Get never runs an
// initializer: a missing key yields a "key not found" error. A context marked
// as the initializer for this key on m yields ErrCacheMapRecursiveCall.
func (m *cacheMap[K, T]) Get(ctx context.Context, key K) (T, error) {
	var zero T
	if ctx.Value(cacheMapContextKey[K, T]{key: key, m: m}) != nil {
		return zero, ErrCacheMapRecursiveCall
	}

	m.l.Lock()
	entry, found := m.calls[key]
	m.l.Unlock()

	if !found {
		return zero, fmt.Errorf("key not found")
	}

	// Wait outside the lock so the initializer can still update the map.
	entry.wg.Wait()
	return entry.val, entry.err
}
// Keys returns a snapshot of every key that currently has a slot, including
// keys whose initializer is still running. The order is unspecified because it
// follows map iteration order.
func (m *cacheMap[K, T]) Keys() []K {
	m.l.Lock()
	defer m.l.Unlock()

	out := make([]K, 0, len(m.calls))
	for k := range m.calls {
		out = append(out, k)
	}
	return out
}