-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtransport.go
103 lines (85 loc) · 2.15 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package multibully
import (
"errors"
"log"
"net"
)
// Transport is the abstraction used to exchange Messages with other
// participants. MulticastTransport is the implementation in this file.
type Transport interface {
	// Read returns the next Message, or an error if one could not be read.
	Read() (*Message, error)
	// Write sends msg and reports any error encountered while doing so.
	Write(msg *Message) error
	// Close releases the resources held by the transport.
	Close() error
}
// MulticastTransport is a Transport backed by a pair of UDP connections:
// one used for receiving and one used for sending.
type MulticastTransport struct {
	// readConn receives datagrams; writeConn sends them.
	readConn, writeConn *net.UDPConn
	// buffer holds received bytes that have not yet been turned into a Message.
	buffer []byte
}
// NewMulticastTransport creates a MulticastTransport that listens for and
// sends UDP datagrams on the multicast group mcastIP at the given port.
//
// mcastIP must be a non-nil multicast address; otherwise an error is
// returned. mcastInterface is passed to net.ListenMulticastUDP to select the
// interface on which the group is joined.
//
// Failures to open either socket are returned to the caller instead of
// terminating the process, and a listening socket that was already opened is
// closed if the sending socket cannot be created.
func NewMulticastTransport(mcastIP *net.IP, mcastInterface *net.Interface, port int) (*MulticastTransport, error) {
	// Guard against a nil pointer before dereferencing it in IsMulticast.
	if mcastIP == nil || !mcastIP.IsMulticast() {
		return nil, errors.New("Address supplied is not a multicast address")
	}

	listenAddr := &net.UDPAddr{IP: *mcastIP, Port: port}
	log.Printf("* Listening on: %+v", listenAddr)
	readConn, err := net.ListenMulticastUDP("udp", mcastInterface, listenAddr)
	if err != nil {
		return nil, err
	}

	broadcastAddr := &net.UDPAddr{IP: *mcastIP, Port: port}
	log.Printf("* Broadcasting on: %+v", broadcastAddr)
	writeConn, err := net.DialUDP("udp", nil, broadcastAddr)
	if err != nil {
		// Do not leak the listening socket when the dial fails; the dial
		// error is the one worth reporting, so the Close error is dropped.
		_ = readConn.Close()
		return nil, err
	}

	return &MulticastTransport{readConn: readConn, writeConn: writeConn, buffer: []byte{}}, nil
}
// Read blocks until at least msgBlockSize bytes are buffered, then decodes
// and returns one Message built from the first msgBlockSize bytes. Any bytes
// beyond that stay in t.buffer for the next call.
//
// If the buffer already holds a complete block from an earlier read, it is
// returned immediately without touching the socket. If reading from the
// socket fails (for example because the connection was closed), the error is
// returned with a nil Message rather than being ignored and retried forever.
func (t *MulticastTransport) Read() (*Message, error) {
	readBuffer := make([]byte, 1500)

	for len(t.buffer) < msgBlockSize {
		num, _, err := t.readConn.ReadFrom(readBuffer)
		// Keep whatever was received, even alongside an error, so no
		// bytes are lost if the caller retries.
		t.buffer = append(t.buffer, readBuffer[:num]...)
		if err != nil {
			return nil, err
		}
	}

	msg := NewMessageFromBytes(t.buffer[:msgBlockSize])
	t.buffer = t.buffer[msgBlockSize:]
	return msg, nil
}
// Write packs m into bytes and sends them on the write connection,
// returning any error reported by the underlying socket write.
func (t *MulticastTransport) Write(m *Message) error {
	if _, err := t.writeConn.Write(m.Pack()); err != nil {
		return err
	}
	return nil
}
// Close closes both the read and the write connection. Both are always
// closed, even if closing the first one fails, so that no socket is leaked.
// The error from the read connection takes precedence; otherwise the error
// from the write connection (if any) is returned.
func (t *MulticastTransport) Close() error {
	readErr := t.readConn.Close()
	writeErr := t.writeConn.Close()
	if readErr != nil {
		return readErr
	}
	return writeErr
}
// getLocalInterfaceIPAddress returns the first non-loopback IPv4 address
// assigned to ifi. It returns an error if the interface's addresses cannot be
// listed or if none of them qualifies.
//
// TODO: this should handle IPv6 addresses
func getLocalInterfaceIPAddress(ifi *net.Interface) (*net.IP, error) {
	candidates, err := ifi.Addrs()
	if err != nil {
		return nil, err
	}

	for _, candidate := range candidates {
		ipnet, isIPNet := candidate.(*net.IPNet)
		if !isIPNet || ipnet.IP.IsLoopback() {
			continue
		}
		// To4 returns nil for addresses that are not IPv4.
		if ipnet.IP.To4() == nil {
			continue
		}
		return &ipnet.IP, nil
	}

	return nil, errors.New("No local interface address found")
}