-
Notifications
You must be signed in to change notification settings - Fork 2
/
strategy.go
130 lines (106 loc) · 2.9 KB
/
strategy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package radium
import (
"context"
"sync"
)
// Strategy describes how a query is run against a set of registered sources.
// Each implementation decides the approach (for example concurrently or one
// source after another) and returns the articles it collected, or an error.
type Strategy interface {
	Execute(ctx context.Context, query Query, sources []RegisteredSource) ([]Article, error)
}
// NewConcurrent returns a Concurrent strategy that reports through the
// given logger.
func NewConcurrent(logger Logger) *Concurrent {
	con := new(Concurrent)
	con.Logger = logger
	return con
}
// NewNthResult returns an NthResult strategy that stops querying further
// sources once at least n results have been collected. Messages are reported
// through the given logger.
func NewNthResult(n int, logger Logger) *NthResult {
	nth := new(NthResult)
	nth.Logger = logger
	nth.stopAt = n
	return nth
}
// Concurrent is a Strategy that searches all sources at the same time.
// The embedded Logger is used to report cancellation and source failures.
type Concurrent struct {
	Logger
}
// Execute runs query against every source in sources concurrently and returns
// the combined results that passed validation.
//
// A failing source never fails the whole query: its error is logged as a
// warning and that source contributes nothing, so the returned error is
// always nil.
//
// If ctx is cancelled while searches are still being started, no further
// sources are launched; searches already in flight are awaited and whatever
// they collected is returned. Results are appended as each search finishes,
// so their relative order across sources is not guaranteed.
func (con Concurrent) Execute(ctx context.Context, query Query, sources []RegisteredSource) ([]Article, error) {
	results := newSafeResults()
	wg := &sync.WaitGroup{}

	for _, source := range sources {
		// A bare `break` inside a `select` only leaves the select statement,
		// not the surrounding loop. Checking ctx.Err() directly lets the
		// break below actually stop launching new searches.
		if ctx.Err() != nil {
			con.Infof("received cancel signal. stopping")
			break
		}

		wg.Add(1)
		// The source is passed as an argument so each goroutine works on its
		// own copy regardless of the loop-variable semantics in use.
		go func(src RegisteredSource) {
			defer wg.Done()

			srcResults, err := src.Search(ctx, query)
			if err != nil {
				con.Warnf("source '%s' failed: %s", src.Name, err)
				return
			}
			results.extend(srcResults, src.Name, con.Logger)
		}(source)
	}

	// All writers are done after Wait, so reading results without the lock
	// is safe here.
	wg.Wait()
	return results.results, nil
}
// NthResult is a Strategy that queries sources one at a time, in the order
// they are given, and stops once enough results have been gathered or every
// source has been tried.
type NthResult struct {
	// Logger reports results that were dropped because they failed validation.
	Logger

	// stopAt is the number of collected results at which Execute stops
	// querying further sources.
	stopAt int
}
// Execute searches each source in srcs, in order, until at least stopAt valid
// results have been collected or all sources have been tried.
//
// Results that fail validation are logged and skipped. Every valid result of
// a source is kept, so the returned slice may hold more than stopAt entries
// when the last queried source pushes the total past the threshold.
//
// The first source error aborts the run and is returned with nil results.
// If ctx is cancelled between sources, iteration stops and the results
// collected so far are returned with a nil error.
func (nth *NthResult) Execute(ctx context.Context, query Query, srcs []RegisteredSource) ([]Article, error) {
	results := []Article{}

	for _, src := range srcs {
		// A bare `break` inside a `select` only leaves the select statement,
		// not the surrounding loop. Checking ctx.Err() directly lets the
		// break below actually stop the iteration on cancellation.
		if ctx.Err() != nil {
			break
		}

		srcResults, err := src.Search(ctx, query)
		if err != nil {
			return nil, err
		}

		for _, res := range srcResults {
			if err := res.Validate(); err != nil {
				nth.Warnf("ignoring invalid result from '%s': %s", src.Name, err)
				continue
			}
			res.Source = src.Name
			results = append(results, res)
		}

		if len(results) >= nth.stopAt {
			break
		}
	}

	return results, nil
}
// newSafeResults returns an empty safeResults with its mutex allocated and
// ready for use.
func newSafeResults() *safeResults {
	sr := new(safeResults)
	sr.mu = new(sync.Mutex)
	return sr
}
// safeResults accumulates articles produced by several goroutines.
type safeResults struct {
	// mu is held by extend while it appends to results.
	mu *sync.Mutex

	// results holds the articles accepted so far.
	results []Article
}
// extend appends the valid entries of results to the collection while holding
// the mutex. Each accepted article is tagged with srcName as its Source;
// entries that fail Validate are reported through logger and dropped.
func (sr *safeResults) extend(results []Article, srcName string, logger Logger) {
	sr.mu.Lock()
	defer sr.mu.Unlock()

	for i := range results {
		// Work on a copy so the caller's slice is left untouched.
		article := results[i]

		if err := article.Validate(); err == nil {
			article.Source = srcName
			sr.results = append(sr.results, article)
		} else {
			logger.Warnf("ignoring invalid result from source '%s': %s", srcName, err)
		}
	}
}