/
utils.go
106 lines (97 loc) · 2.37 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package dicoprotos
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"net"

	"github.com/golang/protobuf/proto"
)
// GetMessageType reports the SelfDescribingMessage type tag for the concrete
// message m. Only the six message types listed below are recognised; for any
// other value (including a nil interface) it panics with "invalid msg type".
func GetMessageType(m proto.Message) SelfDescribingMessage_MessageType {
	var tag SelfDescribingMessage_MessageType
	switch m.(type) {
	case *Handshake:
		tag = SelfDescribingMessage_HANDSHAKE
	case *DoTask:
		tag = SelfDescribingMessage_DO_TASK
	case *TaskStatus:
		tag = SelfDescribingMessage_TASK_STATUS
	case *TaskResult:
		tag = SelfDescribingMessage_TASK_RESULT
	case *SubmitTask:
		tag = SelfDescribingMessage_SUBMIT_TASK
	case *SubmitCode:
		tag = SelfDescribingMessage_SUBMIT_CODE
	default:
		panic("invalid msg type")
	}
	return tag
}
// GetSDMMessageType returns a freshly allocated, zero-valued message of the
// concrete type named by msg's Type tag, suitable as a target for
// proto.Unmarshal of the envelope's Data. It panics with "invalid msg type" if
// the tag is not one of the six recognised message types.
func GetSDMMessageType(msg *SelfDescribingMessage) proto.Message {
	var target proto.Message
	kind := msg.GetType()
	switch kind {
	case SelfDescribingMessage_HANDSHAKE:
		target = new(Handshake)
	case SelfDescribingMessage_DO_TASK:
		target = new(DoTask)
	case SelfDescribingMessage_TASK_STATUS:
		target = new(TaskStatus)
	case SelfDescribingMessage_TASK_RESULT:
		target = new(TaskResult)
	case SelfDescribingMessage_SUBMIT_TASK:
		target = new(SubmitTask)
	case SelfDescribingMessage_SUBMIT_CODE:
		target = new(SubmitCode)
	default:
		panic("invalid msg type")
	}
	return target
}
// DecodeUnknownMessage decodes buf as a SelfDescribingMessage envelope,
// allocates a message of the type named by the envelope's Type tag (via
// GetSDMMessageType), and unmarshals the envelope's Data into it. It reverses
// proto.Marshal(WrapMessage(m)).
//
// buf usually comes from a peer, so an envelope whose type tag is not
// recognised is reported as an error rather than letting GetSDMMessageType's
// panic take down the caller. On any error the returned message is nil.
func DecodeUnknownMessage(buf []byte) (proto.Message, error) {
	sdm := new(SelfDescribingMessage)
	if err := proto.Unmarshal(buf, sdm); err != nil {
		return nil, err
	}

	// GetSDMMessageType panics on an unknown tag; convert that panic into an
	// error here, where untrusted input enters the package. The recover only
	// covers this single call.
	msg, err := func() (m proto.Message, err error) {
		defer func() {
			if r := recover(); r != nil {
				err = fmt.Errorf("dicoprotos: cannot decode message of type %v: %v", sdm.GetType(), r)
			}
		}()
		return GetSDMMessageType(sdm), nil
	}()
	if err != nil {
		return nil, err
	}

	if err = proto.Unmarshal(sdm.GetData(), msg); err != nil {
		// Don't hand back a partially-populated message alongside the error.
		return nil, err
	}
	return msg, nil
}
// WrapMessage serialises msg with proto.Marshal and returns a
// SelfDescribingMessage envelope whose Type records msg's concrete type (as
// reported by GetMessageType) and whose Data holds the encoded bytes.
//
// The type is resolved before marshalling. WrapMessage panics if msg's type is
// not recognised by GetMessageType, or with the marshal error if proto.Marshal
// fails.
func WrapMessage(msg proto.Message) proto.Message {
	tag := GetMessageType(msg)

	encoded, marshalErr := proto.Marshal(msg)
	if marshalErr != nil {
		panic(marshalErr)
	}

	envelope := new(SelfDescribingMessage)
	envelope.Type = &tag
	envelope.Data = encoded
	return envelope
}
// ReadPacket reads one length-prefixed frame from conn and returns its payload.
//
// A frame is a 4-byte big-endian length followed by that many payload bytes,
// the format produced by WritePacket. ReadPacket blocks until the whole frame
// has been read or an error occurs. A zero-length frame yields a non-nil,
// empty slice.
//
// Errors:
//   - io.EOF if the stream ends before any byte of the frame is read (a clean
//     close between frames);
//   - io.ErrUnexpectedEOF if the stream ends part-way through the header or
//     the payload;
//   - any other error returned by conn, unchanged.
//
// The payload buffer grows as data actually arrives rather than being sized up
// front from the length header. A peer therefore cannot force a large
// allocation merely by announcing a big length.
// NOTE(review): there is still no upper bound on frame size. A peer that sends
// a huge length together with the matching data can make this allocate up to
// 4 GiB.
func ReadPacket(conn net.Conn) (packet []byte, err error) {
	var header [4]byte
	if _, err = io.ReadFull(conn, header[:]); err != nil {
		return nil, err
	}
	length := binary.BigEndian.Uint32(header[:])
	if length == 0 {
		return []byte{}, nil
	}

	var payload bytes.Buffer
	if _, err = io.CopyN(&payload, conn, int64(length)); err != nil {
		// The header promised more bytes than the stream delivered: that is a
		// truncated frame, not a clean end of stream.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return payload.Bytes(), nil
}
// WritePacket writes buff to conn as a single length-prefixed frame: a 4-byte
// big-endian length followed by the payload bytes, the format ReadPacket reads.
//
// Header and payload are passed to the connection together as net.Buffers.
// On connections that support it, the net package can send them in one
// batched (writev) call instead of two separate writes.
//
// If buff is too long for the 4-byte length header, WritePacket returns an
// error without writing anything. Previously the length was silently truncated,
// which corrupted the stream's framing. Otherwise it returns the first write
// error from conn, if any.
func WritePacket(conn net.Conn, buff []byte) (err error) {
	const maxFrame = uint64(^uint32(0))
	if uint64(len(buff)) > maxFrame {
		return fmt.Errorf("dicoprotos: packet of %d bytes exceeds maximum frame size of %d bytes", len(buff), maxFrame)
	}

	header := make([]byte, 4)
	binary.BigEndian.PutUint32(header, uint32(len(buff)))

	frame := net.Buffers{header, buff}
	_, err = frame.WriteTo(conn)
	return err
}