/
connection.go
100 lines (83 loc) · 2.2 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package rpc
import (
"bufio"
"encoding/binary"
"fmt"
"net"
)
// Connection is a buffered wrapper around a net.Conn. Read (and the
// readByte/readInt32 helpers) consume from reader; Write buffers into writer,
// and Flush pushes the buffered bytes to conn.
//
// NOTE(review): keep the field order stable unless every composite literal of
// Connection in this package uses field names.
type Connection struct {
	// reader buffers data read from conn.
	reader *bufio.Reader
	// writer buffers data destined for conn; see Flush.
	writer *bufio.Writer
	// conn is the underlying network connection; Close closes it.
	conn net.Conn
	// remoteAddress is the value reported by RemoteAddress.
	remoteAddress string
	// localAddress is the value reported by LocalAddress.
	localAddress string
	// context is the registry this connection was added to, or nil when it is
	// not registered with any Context (Close checks for nil).
	context *Context
	// nextMessageId is the value the next NextMessageId call will return.
	nextMessageId int
}
// Context is a registry of Connections, so that every registered connection
// can be closed at once with Close.
type Context struct {
	// activeConnections maps a per-connection string key (chosen by Wrap) to
	// the registered Connection. NewContext initializes it.
	activeConnections map[string]*Connection
}
// NewContext returns a Context whose connection registry is empty and ready
// for use.
func NewContext() *Context {
	reg := new(Context)
	reg.activeConnections = map[string]*Connection{}
	return reg
}
// Close closes every Connection registered with the Context, discarding the
// error each one returns, and then replaces the registry with a new empty map.
func (context *Context) Close() {
	registered := context.activeConnections
	for _, c := range registered {
		// Connection.Close may delete c from this very map (connections
		// registered by Wrap point back at their Context); deleting entries
		// during a range loop is permitted in Go.
		_ = c.Close()
	}
	context.activeConnections = make(map[string]*Connection)
}
// OpenConnection dials remoteAddress over TCP and returns a buffered
// Connection for it. The dial error, if any, is returned unchanged.
//
// The returned Connection is not registered with any Context. Its remote
// address is the remoteAddress string exactly as passed in, and its local
// address is taken from the dialed socket.
func OpenConnection(remoteAddress string) (*Connection, error) {
	netConn, err := net.Dial("tcp", remoteAddress)
	if err != nil {
		return nil, err
	}
	return &Connection{
		reader:        bufio.NewReader(netConn),
		writer:        bufio.NewWriter(netConn),
		conn:          netConn,
		remoteAddress: remoteAddress,
		localAddress:  netConn.LocalAddr().String(),
	}, nil
}
// Wrap adopts an already-established net.Conn (for example one returned by
// net.Listener.Accept) as a buffered Connection. When context is non-nil the
// Connection is registered with it, so Context.Close will close it later.
// A nil context yields an unregistered Connection. A zero-value Context is
// also accepted; its registry map is created on first use.
func Wrap(conn net.Conn, context *Context) *Connection {
	// Keyed fields make the local/remote assignment explicit; with a
	// positional literal the two addresses are easy to swap.
	res := &Connection{
		reader:        bufio.NewReader(conn),
		writer:        bufio.NewWriter(conn),
		conn:          conn,
		remoteAddress: conn.RemoteAddr().String(),
		localAddress:  conn.LocalAddr().String(),
		context:       context,
	}
	if context != nil {
		if context.activeConnections == nil {
			context.activeConnections = make(map[string]*Connection)
		}
		context.activeConnections[res.registryKey()] = res
	}
	return res
}

// registryKey returns the key under which the Connection is stored in its
// Context. Both endpoints are combined because either one alone can be shared
// by several live connections: connections accepted by one listener typically
// have the same local address, and repeated dials to one server have the same
// remote address.
func (cc Connection) registryKey() string {
	return cc.localAddress + " <-> " + cc.remoteAddress
}

// Close closes the underlying net.Conn and returns its error. If the
// Connection is registered with a Context, it is first removed from that
// Context's registry.
//
// Close does not flush the write buffer; call Flush first to send any data
// still pending in it.
func (cc Connection) Close() error {
	if cc.context != nil {
		key := cc.registryKey()
		// Only remove the entry if it really is this connection. cc is a copy,
		// but every copy shares the same *bufio.Reader pointer.
		if registered, ok := cc.context.activeConnections[key]; ok && registered.reader == cc.reader {
			delete(cc.context.activeConnections, key)
		}
	}
	return cc.conn.Close()
}
// NextMessageId returns the current message-id counter and advances it by
// one, so successive calls yield consecutive integers. OpenConnection and Wrap
// start the counter at zero. No locking is performed.
func (c *Connection) NextMessageId() int {
	c.nextMessageId++
	return c.nextMessageId - 1
}
// LocalAddress returns the local address string recorded when the Connection
// was created.
func (c Connection) LocalAddress() (addr string) {
	addr = c.localAddress
	return
}
// RemoteAddress returns the remote address string recorded when the
// Connection was created.
func (c Connection) RemoteAddress() (addr string) {
	addr = c.remoteAddress
	return
}
// Flush sends any data buffered by Write to the underlying connection and
// reports the bufio.Writer's error, if any.
func (c Connection) Flush() error {
	w := c.writer
	return w.Flush()
}
// Read implements io.Reader by reading from the Connection's buffered reader.
func (c Connection) Read(p []byte) (n int, err error) {
	n, err = c.reader.Read(p)
	return
}
// Write implements io.Writer by appending p to the Connection's write buffer.
// The data reaches the network only when the buffer fills or Flush is called.
func (c Connection) Write(p []byte) (n int, err error) {
	n, err = c.writer.Write(p)
	return
}
// readByte reads one byte from the Connection's buffered reader into *data.
// On error *data is left unchanged and the reader's error is returned
// (io.EOF at end of stream).
func (c Connection) readByte(data *byte) error {
	// ReadByte on the *bufio.Reader avoids copying the whole Connection value
	// into an io.Reader interface (a heap allocation per call) and the generic
	// encoding/binary path. Byte order is irrelevant for a single byte.
	b, err := c.reader.ReadByte()
	if err != nil {
		return err
	}
	*data = b
	return nil
}
// readInt32 reads four bytes from the Connection's buffered reader and decodes
// them into *data using the given byte order. Per encoding/binary.Read, it
// returns io.EOF if no bytes were read and io.ErrUnexpectedEOF if the stream
// ended part-way through the value.
func (c Connection) readInt32(order binary.ByteOrder, data *int32) error {
	// Read from the *bufio.Reader directly (Connection.Read only delegates to
	// it). Passing c itself would copy the whole struct into an io.Reader
	// interface, allocating on every call.
	return binary.Read(c.reader, order, data)
}
// String implements fmt.Stringer, rendering the Connection as
// "connection: <localAddress> <-> <remoteAddress>".
func (c Connection) String() string {
	local, remote := c.localAddress, c.remoteAddress
	return fmt.Sprintf("connection: %v <-> %v", local, remote)
}