summaryrefslogtreecommitdiffstats
path: root/lib/relay/protocol/protocol.go
blob: 0bc079ab6574443cd7b85223c998601d5810863e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).

package protocol

import (
	"errors"
	"fmt"
	"io"
)

const (
	// ProtocolName names the relay protocol. NOTE(review): presumably
	// advertised during connection setup (e.g. TLS ALPN) — confirm against
	// callers.
	ProtocolName = "bep-relay"

	// magic is placed in the header of every message written by
	// WriteMessage; ReadMessage rejects any header whose magic differs.
	magic = 0x9E79BC40
)

// Predefined Response values. Each literal lists a numeric code followed by a
// message string.
var (
	// ResponseSuccess reports a "success" outcome (code 0).
	ResponseSuccess = Response{0, "success"}

	// ResponseNotFound reports a "not found" outcome (code 1).
	ResponseNotFound = Response{1, "not found"}

	// ResponseAlreadyConnected reports an "already connected" outcome (code 2).
	ResponseAlreadyConnected = Response{2, "already connected"}

	// ResponseUnexpectedMessage reports an "unexpected message" outcome
	// (code 100). NOTE(review): presumably sent when a peer receives a message
	// it does not expect in its current state — confirm against callers.
	ResponseUnexpectedMessage = Response{100, "unexpected message"}
)

// WriteMessage encodes message as an XDR payload and writes it to w.
// The payload is preceded by a header carrying the protocol magic, the
// message type and the payload length. Header and payload go out in a
// single Write call.
//
// message must be a value (not a pointer) of one of the relay message types
// listed in the switch below. Any other value yields an
// "unknown message type" error, and nothing is written to w.
func WriteMessage(w io.Writer, message interface{}) error {
	hdr := header{magic: magic}

	// Pick the wire type and bind the matching marshaller; the actual
	// encoding happens once, after the switch.
	var marshal func() ([]byte, error)
	switch msg := message.(type) {
	case Ping:
		hdr.messageType, marshal = messageTypePing, msg.MarshalXDR
	case Pong:
		hdr.messageType, marshal = messageTypePong, msg.MarshalXDR
	case JoinRelayRequest:
		hdr.messageType, marshal = messageTypeJoinRelayRequest, msg.MarshalXDR
	case JoinSessionRequest:
		hdr.messageType, marshal = messageTypeJoinSessionRequest, msg.MarshalXDR
	case Response:
		hdr.messageType, marshal = messageTypeResponse, msg.MarshalXDR
	case ConnectRequest:
		hdr.messageType, marshal = messageTypeConnectRequest, msg.MarshalXDR
	case SessionInvitation:
		hdr.messageType, marshal = messageTypeSessionInvitation, msg.MarshalXDR
	case RelayFull:
		hdr.messageType, marshal = messageTypeRelayFull, msg.MarshalXDR
	default:
		return errors.New("unknown message type")
	}

	payload, err := marshal()
	if err != nil {
		return err
	}

	hdr.messageLength = int32(len(payload))
	hdrBytes, err := hdr.MarshalXDR()
	if err != nil {
		return err
	}

	frame := make([]byte, 0, len(hdrBytes)+len(payload))
	frame = append(frame, hdrBytes...)
	frame = append(frame, payload...)

	_, err = w.Write(frame)
	return err
}

// ReadMessage reads one relay protocol message from r. A message is a
// fixed-size XDR header followed by an XDR payload whose length the header
// announces. The decoded message is returned as a value of the concrete type
// selected by the header's message type (Ping, Pong, JoinRelayRequest,
// JoinSessionRequest, Response, ConnectRequest, SessionInvitation or
// RelayFull).
//
// Errors from io.ReadFull, such as io.EOF or io.ErrUnexpectedEOF, are
// returned unchanged, as are header and payload decoding errors. ReadMessage
// also returns an error for:
//   - a header whose magic is wrong,
//   - a payload length outside [0, maxMessageLength],
//   - an unknown message type.
//
// The payload is fully consumed before the message type is checked. On any
// error the returned message is nil.
func ReadMessage(r io.Reader) (interface{}, error) {
	// Largest payload accepted from the peer. It is checked before the
	// payload buffer is allocated, so a corrupt or hostile header cannot
	// trigger an arbitrarily large allocation.
	const maxMessageLength = 1024

	var hdr header

	hdrBuf := make([]byte, hdr.XDRSize())
	if _, err := io.ReadFull(r, hdrBuf); err != nil {
		return nil, err
	}
	if err := hdr.UnmarshalXDR(hdrBuf); err != nil {
		return nil, err
	}

	if hdr.magic != magic {
		return nil, errors.New("magic mismatch")
	}
	if hdr.messageLength < 0 || hdr.messageLength > maxMessageLength {
		return nil, fmt.Errorf("bad length (%d)", hdr.messageLength)
	}

	payload := make([]byte, int(hdr.messageLength))
	if _, err := io.ReadFull(r, payload); err != nil {
		return nil, err
	}

	var (
		msg interface{}
		err error
	)
	switch hdr.messageType {
	case messageTypePing:
		var m Ping
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypePong:
		var m Pong
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypeJoinRelayRequest:
		var m JoinRelayRequest
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypeJoinSessionRequest:
		var m JoinSessionRequest
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypeResponse:
		var m Response
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypeConnectRequest:
		var m ConnectRequest
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypeSessionInvitation:
		var m SessionInvitation
		err = m.UnmarshalXDR(payload)
		msg = m
	case messageTypeRelayFull:
		var m RelayFull
		err = m.UnmarshalXDR(payload)
		msg = m
	default:
		return nil, errors.New("unknown message type")
	}

	if err != nil {
		// Never hand back a partially decoded message alongside an error.
		return nil, err
	}
	return msg, nil
}