From 610cb0c7d59fb685f2c7d935df9aeae65c8f116c Mon Sep 17 00:00:00 2001 From: Mark Asp Date: Mon, 7 Sep 2020 17:45:44 -0500 Subject: [PATCH] add marshaling of struct and interface types for nbt --- nbt/marshal.go | 69 ++++++++++++++++++------------------- nbt/marshal_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 36 deletions(-) diff --git a/nbt/marshal.go b/nbt/marshal.go index a3da95e..498e5e1 100644 --- a/nbt/marshal.go +++ b/nbt/marshal.go @@ -67,6 +67,7 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { return e.writeInt64(int64(math.Float64bits(val.Float()))) case reflect.Array, reflect.Slice: + n := val.Len() switch val.Type().Elem().Kind() { case reflect.Uint8: // []byte if err := e.writeTag(TagByteArray, tagName); err != nil { @@ -82,7 +83,6 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { if err := e.writeTag(TagIntArray, tagName); err != nil { return err } - n := val.Len() if err := e.writeInt32(int32(n)); err != nil { return err } @@ -96,7 +96,6 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { if err := e.writeTag(TagLongArray, tagName); err != nil { return err } - n := val.Len() if err := e.writeInt32(int32(n)); err != nil { return err } @@ -107,14 +106,7 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { } case reflect.Int16: - if err := e.writeTag(TagList, tagName); err != nil { - return err - } - if _, err := e.w.Write([]byte{TagShort}); err != nil { - return err - } - n := val.Len() - if err := e.writeInt32(int32(n)); err != nil { + if err := e.writeListHeader(TagShort, tagName, val.Len()); err != nil { return err } for i := 0; i < n; i++ { @@ -124,14 +116,7 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { } case reflect.Float32: - if err := e.writeTag(TagList, tagName); err != nil { - return err - } - if _, err := e.w.Write([]byte{TagFloat}); err != nil { - return err - } - n := val.Len() - if err := e.writeInt32(int32(n)); err != nil { + if err := e.writeListHeader(TagFloat, tagName, val.Len()); err != nil { return err } for i := 0; i < n; i++ { @@ -141,14 +126,7 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { } case reflect.Float64: - if err := e.writeTag(TagList, tagName); err != nil { - return err - } - if _, err := e.w.Write([]byte{TagFloat}); err != nil { - return err - } - n := val.Len() - if err := e.writeInt32(int32(n)); err != nil { + if err := e.writeListHeader(TagDouble, tagName, val.Len()); err != nil { return err } for i := 0; i < n; i++ { @@ -158,15 +136,7 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { } case reflect.String: - if err := e.writeTag(TagList, tagName); err != nil { - return err - } - if _, err := e.w.Write([]byte{TagString}); err != nil { - return err - } - n := val.Len() - // Write length of strings - if err := e.writeInt32(int32(n)); err != nil { + if err := e.writeListHeader(TagString, tagName, n); err != nil { return err } for i := 0; i < n; i++ { @@ -180,7 +150,20 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error { return err } } - + case reflect.Struct, reflect.Interface: + if err := e.writeListHeader(TagCompound, tagName, n); err != nil { + return err + } + for i := 0; i < n; i++ { + elemVal := val.Index(i) + if val.Type().Elem().Kind() == reflect.Interface { + elemVal = reflect.ValueOf(elemVal.Interface()) + } + err := e.marshal(elemVal, "") + if err != nil { + return err + } + } default: return errors.New("unknown type " + val.Type().String() + " slice") } @@ -236,6 +219,20 @@ func (e *Encoder) writeTag(tagType byte, tagName string) error { return err } +func (e *Encoder) writeListHeader(elementType byte, tagName string, n int) (err error) { + if err = e.writeTag(TagList, tagName); err != nil { + return + } + if _, err = e.w.Write([]byte{elementType}); err != nil { + return + } + // Write length of strings + if err = e.writeInt32(int32(n)); err != nil { + return + } + return nil +} + func (e *Encoder) writeNamelessTag(tagType byte, tagName string) error { _, err := e.w.Write([]byte{tagType}) return err diff --git a/nbt/marshal_test.go b/nbt/marshal_test.go index b1db812..b253563 100644 --- a/nbt/marshal_test.go +++ b/nbt/marshal_test.go @@ -55,3 +55,87 @@ func TestMarshal_FloatArray(t *testing.T) { t.Errorf("output binary not right: get % 02x, want % 02x ", buf.Bytes(), out) } } + +func TestMarshal_InterfaceArray(t *testing.T) { + type Struct1 struct { + Val int32 + } + + type Struct2 struct { + Val float32 + } + + tests := []struct { + name string + args []interface{} + want []byte + }{ + { + name: "Two element interface array", + args: []interface{}{Struct1{3}, Struct2{0.3}}, + want: []byte{ + TagList, 0x00, 0x00 /*no name*/, TagCompound, 0, 0, 0, 2, + // 1st element + TagCompound, 0x00, 0x00, /*no name*/ + TagInt, 0x00, 0x03, 'V', 'a', 'l', 0x00, 0x00, 0x00, 0x03, // 3 + TagEnd, + // 2nd element + TagCompound, 0x00, 0x00, /*no name*/ + TagFloat, 0x00, 0x03, 'V', 'a', 'l', 0x3e, 0x99, 0x99, 0x9a, // 0.3 + TagEnd, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := Marshal(w, tt.args) + if err != nil { + t.Error(err) + } else if !bytes.Equal(w.Bytes(), tt.want) { + t.Errorf("Marshal([]interface{}) got = % 02x, want % 02x", w.Bytes(), tt.want) + return + } + }) + } +} + +func TestMarshal_StructArray(t *testing.T) { + type Struct1 struct { + Val int32 + } + + tests := []struct { + name string + args []Struct1 + want []byte + }{ + { + name: "One element struct array", + args: []Struct1{{3}, {-10}}, + want: []byte{ + TagList, 0x00, 0x00 /*no name*/, TagCompound, 0, 0, 0, 2, + // 1st element + TagCompound, 0x00, 0x00, /*no name*/ + TagInt, 0x00, 0x03, 'V', 'a', 'l', 0x00, 0x00, 0x00, 0x03, // 3 + TagEnd, + // 2nd element + TagCompound, 0x00, 0x00, /*no name*/ + TagInt, 0x00, 0x03, 'V', 'a', 'l', 0xff, 0xff, 0xff, 0xf6, // -10 + TagEnd, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := Marshal(w, tt.args) + if err != nil { + t.Error(err) + } else if !bytes.Equal(w.Bytes(), tt.want) { + t.Errorf("Marshal([]struct{}) got = % 02x, want % 02x", w.Bytes(), tt.want) + return + } + }) + } +}