This repository has been archived by the owner on Jul 2, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsnmp.go
454 lines (401 loc) · 12 KB
/
snmp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
package snmp
import (
"context"
"fmt"
"math"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/k-sone/snmpgo"
"golang.org/x/xerrors"
)
// ErrMalformedResponse is delivered as MessageResponse.Err when a reply to a
// Get or GetNext request carries a different number of varbinds than the
// request did.
type ErrMalformedResponse struct {
	ExpOIDs, GotOIDs int // varbind counts in the request and in the reply
}
// Error implements the error interface, reporting the expected and received
// varbind counts.
func (e *ErrMalformedResponse) Error() string {
	const format = "Expected %d OIDs in response, got %d"
	return fmt.Sprintf(format, e.ExpOIDs, e.GotOIDs)
}
// MessageRequest represents a single SNMP request (with some special-case
// conditions for bulk walks). Submit it on MessageSender.MC. The result is
// delivered to .Response, to .C, or to both when both are set.
type MessageRequest struct {
	Addr    net.Addr // Destination address. Should be *net.UDPAddr unless testing
	Message *Message // Source message
	// DontRetryOnError, if true, delivers a NoSuchName error reply as-is
	// instead of re-requesting the remaining OIDs without the failing one.
	DontRetryOnError bool

	// Response, if not nil, is called with the result. By default each call
	// runs on its own goroutine; when CallbackInline(true) was set on the
	// MessageSenderOpts it runs inline on one of the sender's internal
	// goroutines, so a slow callback stalls all SNMP processing (be fast!).
	Response func(response MessageResponse)
	// C, if not nil, is a channel created by the user; the result is sent on
	// it. The send blocks until it is received (or the sender shuts down),
	// stalling SNMP processing meanwhile, so keep it drained (be fast!).
	C chan MessageResponse

	// timer fires timeoutAfter after each send. When it fires and the request
	// is still outstanding, the request is removed from MessageSender.active
	// and is either re-sent or, when attempts is exhausted, completed with
	// ErrTimedOut.
	timer        *time.Timer
	attempts     int           // remaining send attempts; 0 before the first send means "use the defaults"
	timeoutAfter time.Duration // period to wait for a reply before resending the SNMP packet
}
// Retries overrides the retry and timeout policy of this request. The request
// is sent at most retries+1 times. After each send the sender waits
// timeoutAfter for a reply before resending; once every attempt has been
// used, the request fails with ErrTimedOut.
//
// If Retries is not called, the sender applies defaultRetry and
// defaultRetryAfter when it first dequeues the request. These are documented
// as equivalent to mr.Retries(2, time.Second).
//
// A negative retries is treated as 0, i.e. a single attempt.
func (mr *MessageRequest) Retries(retries int, timeoutAfter time.Duration) {
	if retries < 0 {
		// retries+1 <= 0 would either be mistaken for "unset" (0) or, when
		// negative, never count down to 0 and so retry forever.
		retries = 0
	}
	mr.attempts = retries + 1
	mr.timeoutAfter = timeoutAfter
}
// send delivers pRes to whoever is waiting on mr. If mr.Response is set, it is
// invoked directly when inline is true and on a fresh goroutine otherwise. If
// mr.C is set, pRes is also sent on it, abandoning the send if ctx is done
// first.
func (mr *MessageRequest) send(ctx context.Context, inline bool, pRes MessageResponse) {
	switch {
	case mr.Response == nil:
		// no callback registered
	case inline:
		mr.Response(pRes)
	default:
		go mr.Response(pRes)
	}

	if mr.C == nil {
		return
	}
	select {
	case mr.C <- pRes:
	case <-ctx.Done():
	}
}
// setRetryCallback consumes one send attempt from mr and arms mr.timer to fire
// after mr.timeoutAfter. When the timer fires, it first checks whether the
// request is still registered in ms.active, i.e. no reply or write failure
// has claimed it. If so, the request is unregistered and then either
// re-queued on ms.MC for another attempt or, when no attempts remain,
// completed with ErrTimedOut.
//
// It panics when mr has no attempts left, which indicates a bug in the caller.
func (mr *MessageRequest) setRetryCallback(ms *MessageSender) {
	if mr.attempts == 0 {
		panic("internal error")
	}
	mr.attempts--

	onTimeout := func() {
		if ms.ctx.Err() != nil {
			return // the sender is shutting down
		}

		rid := int32(mr.Message.Pdu.RequestId())
		ms.activeMu.Lock()
		_, pending := ms.active[rid]
		delete(ms.active, rid) // no-op when absent
		ms.activeMu.Unlock()
		if !pending {
			// Lost the race: the request was already removed for a legit reason.
			return
		}

		if mr.attempts != 0 {
			select {
			case ms.MC <- mr:
			case <-ms.ctx.Done():
			}
			return
		}
		mr.send(ms.ctx, ms.cbInline, MessageResponse{
			Request: mr.Message,
			Err:     ErrTimedOut,
		})
	}
	mr.timer = time.AfterFunc(mr.timeoutAfter, onTimeout)
}
// MessageResponse is sent to the MessageRequest.Response callback and/or the
// MessageRequest.C channel. Exactly one of Response and Err is non-nil.
type MessageResponse struct {
	Request  *Message // the original request
	Response *Message // full response. If non-nil, then Err is nil
	// Err describes why no response is available. In this file it can be
	// ErrTimedOut (all attempts used), *ErrMalformedResponse (wrong varbind
	// count in a Get/GetNext reply), or an error from marshalling the request
	// or writing it to the socket. Other errors, such as ErrWalkSingleOid, may
	// come from code elsewhere in the package.
	Err error
}
// MessageSender is the controller for all SNMP operations. Submit requests by
// sending *MessageRequest values on MC (or *TableRequest values on TC).
// Results come back through each request's Response callback and/or C
// channel. Create one with NewMessageSender.
type MessageSender struct {
	MC chan *MessageRequest
	TC chan *TableRequest

	// The two counters below are updated with sync/atomic; read them with
	// atomic.LoadUint64. They stay near the start of the struct because 64-bit
	// atomic operations require 64-bit alignment on 32-bit platforms.

	// OrphanedReplies counts replies that do not match a known request
	// (perhaps the request already timed out internally).
	OrphanedReplies uint64
	// IllegalHostReplies counts replies carrying a valid request ID that came
	// from host B when the request was sent to host A.
	IllegalHostReplies uint64

	conn     net.PacketConn
	ctx      context.Context
	cancel   context.CancelFunc
	onErr    ErrorLogger
	cbInline bool
	active   map[int32]*MessageRequest // outstanding requests by PDU request ID; guarded by activeMu
	activeMu sync.Mutex
	wg       sync.WaitGroup
}
// NewMessageSenderOpts returns a MessageSenderOpts pre-populated with the
// default settings. Adjust it with the chained setter methods and pass it to
// NewMessageSender.
func NewMessageSenderOpts() *MessageSenderOpts {
	return new(MessageSenderOpts).init()
}
// MessageSenderOpts holds the tunables for NewMessageSender. Its fields are
// unexported; obtain one from NewMessageSenderOpts and configure it through
// its chained setter methods.
type MessageSenderOpts struct {
	_init    bool           // set once init has applied the defaults
	chanSize int            // buffer size of MessageSender.MC and .TC
	conn     net.PacketConn // socket to use; nil makes NewMessageSender bind UDP ":0"
	onErr    ErrorLogger    // receives internal errors
	cbInline bool           // run Response callbacks inline instead of on new goroutines
}
// init applies the default settings the first time it is called on o and
// does nothing on later calls. It returns o to allow chaining.
func (o *MessageSenderOpts) init() *MessageSenderOpts {
	if o._init {
		return o
	}
	*o = MessageSenderOpts{
		_init:    true,
		chanSize: DefaultChanSize,
		conn:     nil,
		onErr:    DefaultErrorLogger,
		cbInline: false,
	}
	return o
}
// ChanSize (which defaults to DefaultChanSize) controls the buffer size of
// MessageSender.MC and MessageSender.TC. Setting this to 0 makes both
// channels unbuffered, so every submission blocks until the sender picks it
// up. Negative sizes are treated as 0.
func (o *MessageSenderOpts) ChanSize(sz int) *MessageSenderOpts {
	// Apply the defaults first. Otherwise, on a zero-value MessageSenderOpts,
	// the init() call in NewMessageSender would overwrite this setting.
	o.init()
	if sz < 0 {
		sz = 0 // make(chan T, negative) panics
	}
	o.chanSize = sz
	return o
}
// Conn sets the socket on which all SNMP packets are sent and received. By
// default (or when conn is nil), NewMessageSender binds to UDP ':0'. That
// creates a UDP socket on an ephemeral port, which on Linux falls in the
// range shown by this command:
// sysctl net.ipv4.ip_local_port_range
func (o *MessageSenderOpts) Conn(conn net.PacketConn) *MessageSenderOpts {
	// Apply the defaults first. Otherwise, on a zero-value MessageSenderOpts,
	// the init() call in NewMessageSender would overwrite this setting.
	o.init()
	o.conn = conn
	return o
}
// OnErr sets the ErrorLogger that receives every internal error (socket read
// failures, unparsable packets, replies from unexpected hosts). By default
// DefaultErrorLogger is used (documented as logging via the 'log' package);
// passing nil restores that default. To supply a plain function, use
// OnErrFunc instead.
func (o *MessageSenderOpts) OnErr(onErr ErrorLogger) *MessageSenderOpts {
	// Apply the defaults first. Otherwise, on a zero-value MessageSenderOpts,
	// the init() call in NewMessageSender would overwrite this setting.
	o.init()
	if onErr == nil {
		// A nil logger would panic inside the sender's goroutines on the
		// first error.
		onErr = DefaultErrorLogger
	}
	o.onErr = onErr
	return o
}
// OnErrFunc is like OnErr but accepts a plain func(error) instead of an
// ErrorLogger. By default DefaultErrorLogger is used (documented as logging
// via the 'log' package); passing nil restores that default.
func (o *MessageSenderOpts) OnErrFunc(onErr func(error)) *MessageSenderOpts {
	// Apply the defaults first. Otherwise, on a zero-value MessageSenderOpts,
	// the init() call in NewMessageSender would overwrite this setting.
	o.init()
	if onErr == nil {
		// logErrFunc(nil) is a non-nil ErrorLogger that panics when called.
		o.onErr = DefaultErrorLogger
		return o
	}
	o.onErr = logErrFunc(onErr)
	return o
}
// CallbackInline controls whether each MessageRequest.Response callback runs
// on its own goroutine (false) or inline on the sender's internal goroutine
// that produced the result (true). Inline callbacks must be fast, because
// they block SNMP processing. If you leave the .Response attr nil and only
// use channels, this setting is irrelevant.
//
// By default this is false, so each callback runs on its own goroutine.
func (o *MessageSenderOpts) CallbackInline(inline bool) *MessageSenderOpts {
	// Apply the defaults first. Otherwise, on a zero-value MessageSenderOpts,
	// the init() call in NewMessageSender would overwrite this setting.
	o.init()
	o.cbInline = inline
	return o
}
// logErrFunc adapts a plain func(error) so it can be stored where an
// ErrorLogger is expected (see OnErrFunc).
type logErrFunc func(error)

// Log hands err to the wrapped function.
func (fn logErrFunc) Log(err error) {
	fn(err)
}
// NewMessageSender sets up a *MessageSender that sends and receives SNMP
// messages asynchronously over a single socket.
//
// Cancelling ctx gracefully shuts down every goroutine started by this call,
// and MessageSender.Wait blocks until they have all exited. If opts supplies
// no connection, a UDP socket bound to ":0" is created here and closed during
// that shutdown. A connection supplied through MessageSenderOpts.Conn is left
// open for its owner.
//
// If opts is nil, the defaults from NewMessageSenderOpts are used; internal
// error messages then go to DefaultErrorLogger. An error is returned only if
// the default socket cannot be created.
func NewMessageSender(ctx context.Context, opts *MessageSenderOpts) (*MessageSender, error) {
	if opts == nil {
		opts = NewMessageSenderOpts()
	}
	opts.init()

	conn := opts.conn
	ownConn := conn == nil
	if ownConn {
		var err error
		conn, err = net.ListenPacket("udp", ":0")
		if err != nil {
			return nil, err
		}
	}

	ctx, cancel := context.WithCancel(ctx)
	ms := &MessageSender{
		conn:     conn,
		ctx:      ctx,
		cancel:   cancel,
		active:   make(map[int32]*MessageRequest),
		onErr:    opts.onErr,
		cbInline: opts.cbInline,
		MC:       make(chan *MessageRequest, opts.chanSize),
		TC:       make(chan *TableRequest, opts.chanSize),
	}

	ms.wg.Add(4)
	go ms.deadlineOnCancel()
	go ms.messageChanListener()
	go ms.tableChanListener()
	go ms.onRecv()

	if ownConn {
		// The socket was created here and callers cannot reach it, so close
		// it on shutdown to avoid leaking one socket per sender. Any I/O
		// racing with the close fails just as it would against the expired
		// deadline set by deadlineOnCancel.
		ms.wg.Add(1)
		go func() {
			defer ms.wg.Done()
			<-ms.ctx.Done()
			_ = conn.Close() // nothing useful to do with a close error at shutdown
		}()
	}
	return ms, nil
}
// Wait blocks until every goroutine started by NewMessageSender has exited.
// That happens after the context given to NewMessageSender is cancelled.
func (ms *MessageSender) Wait() {
	ms.wg.Wait()
}
// deadlineOnCancel waits for ms.ctx to be cancelled and then gives the socket
// an already-expired deadline. Pending and future reads and writes then fail,
// which lets the blocking ReadFrom in onRecv return so that goroutine can
// exit.
func (ms *MessageSender) deadlineOnCancel() {
	defer ms.wg.Done()
	<-ms.ctx.Done()
	now := time.Now()
	// Best effort during shutdown; a failure here leaves nothing to recover.
	_ = ms.conn.SetDeadline(now)
}
// messageChanListener is the sending goroutine. It drains ms.MC and, for each
// request:
//   - applies the default retry policy if none was set;
//   - assigns the next PDU request ID;
//   - registers the request in ms.active and arms its retry timer;
//   - writes the marshalled packet to the socket.
//
// Marshal and write failures are delivered to the request as
// MessageResponse.Err. The goroutine returns once ms.ctx is cancelled.
func (ms *MessageSender) messageChanListener() {
	defer ms.wg.Done()
	reqID := rand.Int31()
	for {
		var req *MessageRequest
		select {
		case <-ms.ctx.Done():
			return
		case req = <-ms.MC:
		}

		if req.attempts == 0 {
			req.Retries(defaultRetry, defaultRetryAfter)
		}
		// Request IDs stay non-negative int32s: wrap to 0 instead of overflowing.
		if reqID == math.MaxInt32 {
			reqID = 0
		} else {
			reqID++
		}
		req.Message.Pdu.SetRequestId(int(reqID))
		buf, err := req.Message.Marshal()
		if err != nil {
			req.send(ms.ctx, ms.cbInline, MessageResponse{
				Request: req.Message,
				Err:     err,
			})
			continue
		}

		// Register the request and arm its timer under one critical section.
		// findReq and the timer callback access req.timer and ms.active only
		// while holding activeMu, so they always see a fully armed request.
		ms.activeMu.Lock()
		ms.active[reqID] = req
		req.setRetryCallback(ms)
		ms.activeMu.Unlock()

		if _, err := ms.conn.WriteTo(buf, req.Addr); err != nil {
			// Unregister before reporting. If the error were delivered first,
			// a blocking delivery would give the retry timer time to re-queue
			// the request, and the caller would get two results.
			ms.activeMu.Lock()
			_, pending := ms.active[reqID]
			delete(ms.active, reqID)
			if req.timer != nil {
				req.timer.Stop()
				req.timer = nil
			}
			ms.activeMu.Unlock()
			// If the timer already claimed the request, it has been re-queued
			// or completed with ErrTimedOut, so there is nothing to report.
			if pending {
				req.send(ms.ctx, ms.cbInline, MessageResponse{
					Request: req.Message,
					Err:     err,
				})
			}
		}
	}
}
// onRecv is the receiving goroutine. It reads datagrams from ms.conn, parses
// each one as an SNMP message, matches it to its outstanding request via
// findReq, and delivers the result:
//   - A Get/GetNext reply whose varbind count differs from the request's is
//     reported as *ErrMalformedResponse.
//   - For a NoSuchName error reply (unless DontRetryOnError is set), a
//     multi-varbind reply is re-requested without the failing OID, and a
//     single-varbind reply has its value replaced with NoSuchInstance.
//   - Everything else is delivered unchanged.
//
// It returns when a read fails. During shutdown this is expected, because
// deadlineOnCancel expires the socket deadline. Any other read failure
// cancels the sender and is logged.
func (ms *MessageSender) onRecv() {
	defer ms.wg.Done()
	var buf [65536]byte
	for {
		n, addr, err := ms.conn.ReadFrom(buf[:])
		if err != nil {
			// Most likely because err.(net.Error).Timeout() == true during
			// shutdown; only an unexpected failure is worth reporting.
			if ms.ctx.Err() == nil {
				ms.cancel()
				ms.onErr.Log(xerrors.Errorf("error reading from socket: %w", err))
			}
			return
		}

		// buf is overwritten by the next read, while the parsed message is
		// handed to user code that may keep it, possibly on another goroutine.
		// Parse a private copy so the response cannot alias buf, even if
		// Unmarshal keeps sub-slices of its input (e.g. raw ASN.1 values).
		pkt := make([]byte, n)
		copy(pkt, buf[:n])
		res := &Message{}
		if _, err := res.Unmarshal(pkt); err != nil {
			ms.onErr.Log(xerrors.Errorf("[%s] unable to parse inbound SNMP packet of len %d: %w", addr, n, err))
			continue
		}

		req := ms.findReq(addr, res)
		if req == nil {
			continue
		}

		reqVbs := req.Message.Pdu.VarBinds()
		resVbs := res.Pdu.VarBinds()
		pt := req.Message.Pdu.PduType()
		if (pt == snmpgo.GetRequest || pt == snmpgo.GetNextRequest) && len(reqVbs) != len(resVbs) {
			req.send(ms.ctx, ms.cbInline, MessageResponse{
				Request: req.Message,
				Err: &ErrMalformedResponse{
					ExpOIDs: len(reqVbs),
					GotOIDs: len(resVbs),
				},
			})
			continue
		}

		if res.Pdu.ErrorStatus() == snmpgo.NoSuchName && !req.DontRetryOnError {
			if len(resVbs) > 1 {
				ms.queueResendWithOmissions(req, res)
				continue
			}
			// The reply comes from the network: a zero-varbind NoSuchName
			// reply is possible and is delivered unchanged instead of
			// panicking on resVbs[0].
			if len(resVbs) == 1 {
				// Could just as easily be NoSuchObject as well, or left blank. Unsure.
				resVbs[0].Variable = snmpgo.NewNoSucheInstance()
			}
		}

		req.send(ms.ctx, ms.cbInline, MessageResponse{
			Request:  req.Message,
			Response: res,
		})
	}
}
// queueResendWithOmissions handles a NoSuchName error reply to a request that
// carries several varbinds. It queues a copy of req without the varbind
// flagged by the reply's error-index. When that retry completes, the omitted
// OID is put back at its original position with a NoSuchInstance value, and
// the merged response is delivered to req. If the retry fails, its error is
// delivered to req instead.
//
// SNMP error-index values are 1-based (RFC 1157, RFC 3416). An index outside
// [1, len(varbinds)] identifies no varbind, so res is then delivered to req
// unchanged.
func (ms *MessageSender) queueResendWithOmissions(req *MessageRequest, res *Message) {
	reqVbs := req.Message.Pdu.VarBinds()
	errIdx := res.Pdu.ErrorIndex()
	if errIdx < 1 || errIdx > len(reqVbs) {
		req.send(ms.ctx, ms.cbInline, MessageResponse{
			Request:  req.Message,
			Response: res,
		})
		return
	}
	omit := errIdx - 1 // zero-based position of the failing varbind
	omitted := reqVbs[omit]

	retry := &MessageRequest{
		Addr:         req.Addr,
		attempts:     req.attempts,
		timeoutAfter: req.timeoutAfter,
		Message: &Message{
			Community: req.Message.Community,
			Version:   req.Message.Version,
			Pdu:       snmpgo.NewPdu(req.Message.Version, req.Message.Pdu.PduType()),
		},
	}
	for i, vb := range reqVbs {
		if i != omit {
			retry.Message.Pdu.AppendVarBind(vb.Oid, snmpgo.NewNull())
		}
	}

	retry.Response = func(r MessageResponse) {
		if r.Err != nil {
			// Report the failure against the caller's original request;
			// r.Response is nil here.
			req.send(ms.ctx, ms.cbInline, MessageResponse{
				Request: req.Message,
				Err:     r.Err,
			})
			return
		}
		// Rebuild the varbinds in the original request order, re-inserting
		// the omitted OID at index omit (appending it if it was the last one).
		resVbs := r.Response.Pdu.VarBinds()
		merged := snmpgo.NewPdu(r.Response.Version, r.Response.Pdu.PduType())
		merged.SetRequestId(r.Response.Pdu.RequestId())
		for i, vb := range resVbs {
			if i == omit {
				merged.AppendVarBind(omitted.Oid, snmpgo.NewNoSucheInstance())
			}
			merged.AppendVarBind(vb.Oid, vb.Variable)
		}
		if omit >= len(resVbs) {
			merged.AppendVarBind(omitted.Oid, snmpgo.NewNoSucheInstance())
		}
		r.Response.Pdu = merged
		req.send(ms.ctx, ms.cbInline, MessageResponse{
			Request:  req.Message,
			Response: r.Response,
		})
	}

	// This runs on the receive goroutine: do not block forever on a full MC
	// once the sender is shutting down.
	select {
	case <-ms.ctx.Done():
	case ms.MC <- retry:
	}
}
// findReq returns the outstanding request that res answers, looked up by PDU
// request ID, provided the reply came from the address the request was sent
// to. On a match the request is removed from ms.active, its retry timer is
// stopped, and it is returned. Otherwise findReq returns nil and increments
// OrphanedReplies (unknown ID) or IllegalHostReplies (wrong sender). In the
// wrong-sender case the request stays pending for the legitimate reply.
func (ms *MessageSender) findReq(addr net.Addr, res *Message) *MessageRequest {
	rid := int32(res.Pdu.RequestId())

	ms.activeMu.Lock()
	defer ms.activeMu.Unlock()

	req, ok := ms.active[rid]
	if !ok {
		atomic.AddUint64(&ms.OrphanedReplies, 1)
		return nil
	}

	// Compare two UDP addresses structurally, so equivalent encodings of the
	// same IP (e.g. IPv4 vs IPv4-mapped IPv6) match. Any other combination
	// falls back to comparing the String forms.
	var sameHost bool
	got, gotUDP := addr.(*net.UDPAddr)
	want, wantUDP := req.Addr.(*net.UDPAddr)
	if gotUDP && wantUDP {
		sameHost = got.IP.Equal(want.IP) && got.Port == want.Port && got.Zone == want.Zone
	} else {
		sameHost = addr.String() == req.Addr.String()
	}
	if !sameHost {
		atomic.AddUint64(&ms.IllegalHostReplies, 1)
		ms.onErr.Log(xerrors.Errorf("[%s] requestId to address, wanted %s", addr, req.Addr))
		return nil
	}

	if req.timer != nil {
		req.timer.Stop()
		req.timer = nil
	}
	delete(ms.active, rid)
	return req
}