summaryrefslogtreecommitdiffstats
path: root/xdr
diff options
context:
space:
mode:
authorJakob Borg <jakob@nym.se>2014-02-15 12:08:55 +0100
committerJakob Borg <jakob@nym.se>2014-02-15 12:08:55 +0100
commitf89fa6caedc491bb4b735633e04b2b52a023fc51 (patch)
treef6f558adea265bc0144021681977619d6b2e4c13 /xdr
parent21a7f3960a6c711c63edf2bb35fc4cc707100394 (diff)
Factor out XDR en/decoding
Diffstat (limited to 'xdr')
-rw-r--r--xdr/reader.go65
-rw-r--r--xdr/writer.go95
-rw-r--r--xdr/xdr_test.go57
3 files changed, 217 insertions, 0 deletions
diff --git a/xdr/reader.go b/xdr/reader.go
new file mode 100644
index 0000000000..b5c39a6b2a
--- /dev/null
+++ b/xdr/reader.go
@@ -0,0 +1,65 @@
+package xdr
+
+import "io"
+
+type Reader struct {
+ r io.Reader
+ tot uint64
+ err error
+ b [8]byte
+}
+
+func NewReader(r io.Reader) *Reader {
+ return &Reader{
+ r: r,
+ }
+}
+
+func (r *Reader) ReadString() string {
+ return string(r.ReadBytes(nil))
+}
+
+func (r *Reader) ReadBytes(dst []byte) []byte {
+ if r.err != nil {
+ return nil
+ }
+ l := int(r.ReadUint32())
+ if r.err != nil {
+ return nil
+ }
+ if l+pad(l) > len(dst) {
+ dst = make([]byte, l+pad(l))
+ } else {
+ dst = dst[:l+pad(l)]
+ }
+ _, r.err = io.ReadFull(r.r, dst)
+ r.tot += uint64(l + pad(l))
+ return dst[:l]
+}
+
+func (r *Reader) ReadUint32() uint32 {
+ if r.err != nil {
+ return 0
+ }
+ _, r.err = io.ReadFull(r.r, r.b[:4])
+ r.tot += 8
+ return uint32(r.b[3]) | uint32(r.b[2])<<8 | uint32(r.b[1])<<16 | uint32(r.b[0])<<24
+}
+
+func (r *Reader) ReadUint64() uint64 {
+ if r.err != nil {
+ return 0
+ }
+ _, r.err = io.ReadFull(r.r, r.b[:8])
+ r.tot += 8
+ return uint64(r.b[7]) | uint64(r.b[6])<<8 | uint64(r.b[5])<<16 | uint64(r.b[4])<<24 |
+ uint64(r.b[3])<<32 | uint64(r.b[2])<<40 | uint64(r.b[1])<<48 | uint64(r.b[0])<<56
+}
+
+func (r *Reader) Tot() uint64 {
+ return r.tot
+}
+
+func (r *Reader) Err() error {
+ return r.err
+}
diff --git a/xdr/writer.go b/xdr/writer.go
new file mode 100644
index 0000000000..30c7c5639c
--- /dev/null
+++ b/xdr/writer.go
@@ -0,0 +1,95 @@
+package xdr
+
+import "io"
+
+func pad(l int) int {
+ d := l % 4
+ if d == 0 {
+ return 0
+ }
+ return 4 - d
+}
+
+var padBytes = []byte{0, 0, 0}
+
+type Writer struct {
+ w io.Writer
+ tot uint64
+ err error
+ b [8]byte
+}
+
+func NewWriter(w io.Writer) *Writer {
+ return &Writer{
+ w: w,
+ }
+}
+
+func (w *Writer) WriteString(s string) (int, error) {
+ return w.WriteBytes([]byte(s))
+}
+
+func (w *Writer) WriteBytes(bs []byte) (int, error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+
+ w.WriteUint32(uint32(len(bs)))
+ if w.err != nil {
+ return 0, w.err
+ }
+
+ var l, n int
+ n, w.err = w.w.Write(bs)
+ l += n
+
+ if p := pad(len(bs)); w.err == nil && p > 0 {
+ n, w.err = w.w.Write(padBytes[:p])
+ l += n
+ }
+
+ w.tot += uint64(l)
+ return l, w.err
+}
+
+func (w *Writer) WriteUint32(v uint32) (int, error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+ w.b[0] = byte(v >> 24)
+ w.b[1] = byte(v >> 16)
+ w.b[2] = byte(v >> 8)
+ w.b[3] = byte(v)
+
+ var l int
+ l, w.err = w.w.Write(w.b[:4])
+ w.tot += uint64(l)
+ return l, w.err
+}
+
+func (w *Writer) WriteUint64(v uint64) (int, error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+ w.b[0] = byte(v >> 56)
+ w.b[1] = byte(v >> 48)
+ w.b[2] = byte(v >> 40)
+ w.b[3] = byte(v >> 32)
+ w.b[4] = byte(v >> 24)
+ w.b[5] = byte(v >> 16)
+ w.b[6] = byte(v >> 8)
+ w.b[7] = byte(v)
+
+ var l int
+ l, w.err = w.w.Write(w.b[:8])
+ w.tot += uint64(l)
+ return l, w.err
+}
+
+func (w *Writer) Tot() uint64 {
+ return w.tot
+}
+
+func (w *Writer) Err() error {
+ return w.err
+}
diff --git a/xdr/xdr_test.go b/xdr/xdr_test.go
new file mode 100644
index 0000000000..859958ef88
--- /dev/null
+++ b/xdr/xdr_test.go
@@ -0,0 +1,57 @@
+package xdr
+
+import (
+ "bytes"
+ "testing"
+ "testing/quick"
+)
+
+func TestPad(t *testing.T) {
+ tests := [][]int{
+ {0, 0},
+ {1, 3},
+ {2, 2},
+ {3, 1},
+ {4, 0},
+ {32, 0},
+ {33, 3},
+ }
+ for _, tc := range tests {
+ if p := pad(tc[0]); p != tc[1] {
+ t.Errorf("Incorrect padding for %d bytes, %d != %d", tc[0], p, tc[1])
+ }
+ }
+}
+
+func TestBytesNil(t *testing.T) {
+ fn := func(bs []byte) bool {
+ var b = new(bytes.Buffer)
+ var w = NewWriter(b)
+ var r = NewReader(b)
+ w.WriteBytes(bs)
+ w.WriteBytes(bs)
+ r.ReadBytes(nil)
+ res := r.ReadBytes(nil)
+ return bytes.Compare(bs, res) == 0
+ }
+ if err := quick.Check(fn, nil); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestBytesGiven(t *testing.T) {
+ fn := func(bs []byte) bool {
+ var b = new(bytes.Buffer)
+ var w = NewWriter(b)
+ var r = NewReader(b)
+ w.WriteBytes(bs)
+ w.WriteBytes(bs)
+ res := make([]byte, 12)
+ res = r.ReadBytes(res)
+ res = r.ReadBytes(res)
+ return bytes.Compare(bs, res) == 0
+ }
+ if err := quick.Check(fn, nil); err != nil {
+ t.Error(err)
+ }
+}