diff --git a/nbt/decode.go b/nbt/decode.go index 5a63e20..bd90e64 100644 --- a/nbt/decode.go +++ b/nbt/decode.go @@ -10,31 +10,32 @@ import ( ) func Unmarshal(data []byte, v interface{}) error { - return NewDecoder(bytes.NewReader(data)).Decode(v) + _, err := NewDecoder(bytes.NewReader(data)).Decode(v) + return err } -func (d *Decoder) Decode(v interface{}) error { +func (d *Decoder) Decode(v interface{}) (string, error) { val := reflect.ValueOf(v) if val.Kind() != reflect.Ptr { - return errors.New("nbt: non-pointer passed to Unmarshal") + return "", errors.New("nbt: non-pointer passed to Decode") } //start read NBT tagType, tagName, err := d.readTag() if err != nil { - return fmt.Errorf("nbt: %w", err) + return tagName, fmt.Errorf("nbt: %w", err) } if c := d.checkCompressed(tagType); c != "" { - return fmt.Errorf("nbt: unknown Tag, maybe need %s", c) + return tagName, fmt.Errorf("nbt: unknown Tag, maybe need %s", c) } - // We decode val not val.Elem because the Unmarshaler interface + // We decode val not val.Elem because the NBTDecoder interface // test must be applied at the top level of the value. err = d.unmarshal(val, tagType, tagName) if err != nil { - return fmt.Errorf("nbt: fail to decode tag %q: %w", tagName, err) + return tagName, fmt.Errorf("nbt: fail to decode tag %q: %w", tagName, err) } - return nil + return tagName, nil } // check the first byte and return if it use compress @@ -53,7 +54,7 @@ var ErrEND = errors.New("unexpected TAG_End") func (d *Decoder) unmarshal(val reflect.Value, tagType byte, tagName string) error { u, val := indirect(val, tagType == TagEnd) if u != nil { - return u.Unmarshal(tagType, tagName, d.r) + return u.Decode(tagType, d.r) } switch tagType { @@ -357,12 +358,12 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte, tagName string) err // indirect walks down v allocating pointers as needed, // until it gets to a non-pointer. -// If it encounters an Unmarshaler, indirect stops and returns that. +// If it encounters an NBTDecoder, indirect stops and returns that. // If decodingNull is true, indirect stops at the first settable pointer so it // can be set to nil. // // This function is copied and modified from encoding/json -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { +func indirect(v reflect.Value, decodingNull bool) (NBTDecoder, reflect.Value) { v0 := v haveAddr := false @@ -404,7 +405,7 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { v.Set(reflect.New(v.Type().Elem())) } if v.Type().NumMethod() > 0 && v.CanInterface() { - if u, ok := v.Interface().(Unmarshaler); ok { + if u, ok := v.Interface().(NBTDecoder); ok { return u, reflect.Value{} } } diff --git a/nbt/decode_test.go b/nbt/decode_test.go index 7b00740..67da161 100644 --- a/nbt/decode_test.go +++ b/nbt/decode_test.go @@ -182,14 +182,14 @@ func MakeBigTestStruct() BigTestStruct { return want } -func TestUnmarshal_bigTest(t *testing.T) { +func TestDecoder_Decode_bigTest(t *testing.T) { //test parse var value BigTestStruct r, err := gzip.NewReader(bytes.NewReader(bigTestData[:])) if err != nil { t.Fatal(err) } - if err := NewDecoder(r).Decode(&value); err != nil { + if _, err := NewDecoder(r).Decode(&value); err != nil { t.Fatal(err) } @@ -204,7 +204,7 @@ func TestUnmarshal_bigTest(t *testing.T) { if err != nil { t.Fatal(err) } - if err := NewDecoder(r).Decode(&empty); err != nil { + if _, err := NewDecoder(r).Decode(&empty); err != nil { t.Fatal(err) } @@ -213,20 +213,20 @@ func TestUnmarshal_bigTest(t *testing.T) { if err != nil { t.Fatal(err) } - if err := NewDecoder(r).Decode(&inf); err != nil { + if _, err := NewDecoder(r).Decode(&inf); err != nil { t.Fatal(err) } // t.Log(inf) } -func BenchmarkUnmarshal_bigTest(b *testing.B) { +func BenchmarkDecoder_Decode_bigTest(b *testing.B) { var value BigTestStruct for i := 0; i < b.N; i++ { r, err := gzip.NewReader(bytes.NewReader(bigTestData[:])) if err != nil { b.Fatal(err) } - if err := NewDecoder(r).Decode(&value); err != nil { + if _, err := NewDecoder(r).Decode(&value); err != nil { b.Fatal(err) } } @@ -256,7 +256,7 @@ func TestDecoder_overRead(t *testing.T) { r := bytes.NewReader(enc) // Count read bytes by using io.LimitReader rr := io.LimitReader(r, math.MaxInt64).(*io.LimitedReader) - if err := NewDecoder(rr).Decode(&value); err != nil { + if _, err := NewDecoder(rr).Decode(&value); err != nil { t.Fatal(err) } @@ -266,7 +266,7 @@ func TestDecoder_overRead(t *testing.T) { } } -func TestUnmarshal_IntArray(t *testing.T) { +func TestDecoder_Decode_IntArray(t *testing.T) { data := []byte{ TagIntArray, 0, 0, 0, 0, 0, 3, @@ -300,7 +300,7 @@ func TestUnmarshal_IntArray(t *testing.T) { // t.Log(value, value2) } -func TestUnmarshal_LongArray(t *testing.T) { +func TestDecoder_Decode_LongArray(t *testing.T) { data := []byte{ TagLongArray, 0, 0, 0, 0, 0, 3, @@ -332,7 +332,7 @@ func TestUnmarshal_LongArray(t *testing.T) { // t.Log(infValue) } -func TestUnmarshal_ByteArray(t *testing.T) { +func TestDecoder_Decode_ByteArray(t *testing.T) { data := []byte{ TagByteArray, 0, 0, 0, 0, 0, 7, @@ -363,7 +363,7 @@ func TestUnmarshal_ByteArray(t *testing.T) { // t.Log(infValue) } -func TestUnmarshal_ErrorString(t *testing.T) { +func TestDecoder_Decode_ErrorString(t *testing.T) { var data = []byte{ 0x08, 0x00, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0xFF, 0xFE, 0x42, 0x61, 0x6e, 0x61, 0x6e, 0x72, 0x61, 0x6d, 0x61, @@ -379,3 +379,32 @@ func TestUnmarshal_ErrorString(t *testing.T) { t.Log(err) } + +func TestDecoder_Decode_rawMessage(t *testing.T) { + data := []byte{ + TagCompound, 0, 2, 'a', 'b', + TagInt, 0, 3, 'K', 'e', 'y', 0, 0, 0, 12, + TagString, 0, 5, 'V', 'a', 'l', 'u', 'e', 0, 4, 'T', 'n', 'z', 'e', + TagEnd, + } + var container struct { + Key int32 + Value RawMessage + } + + if tag, err := NewDecoder(bytes.NewReader(data)).Decode(&container); err != nil { + t.Fatal(tag) + } else { + if tag != "ab" { + t.Fatalf("Decode tag name error: want %s, get: %s", "ab", tag) + } + if container.Key != 12 { + t.Fatalf("Decode Key error: want %v, get: %v", 12, container.Key) + } + if !bytes.Equal(container.Value.Data, []byte{ + 0, 4, 'T', 'n', 'z', 'e', + }) { + t.Fatalf("Decode Key error: get: %v", container.Value) + } + } +} diff --git a/nbt/encode.go b/nbt/encode.go index 29e5044..5bf92b6 100644 --- a/nbt/encode.go +++ b/nbt/encode.go @@ -11,9 +11,9 @@ import ( "strings" ) -func Marshal(v interface{}, optionalTagName ...string) ([]byte, error) { +func Marshal(v interface{}) ([]byte, error) { var buf bytes.Buffer - err := NewEncoder(&buf).Encode(v, optionalTagName...) + err := NewEncoder(&buf).Encode(v, "") return buf.Bytes(), err } @@ -25,25 +25,31 @@ func NewEncoder(w io.Writer) *Encoder { return &Encoder{w: w} } -func (e *Encoder) Encode(v interface{}, optionalTagName ...string) error { +func (e *Encoder) Encode(v interface{}, tagName string) error { val := reflect.ValueOf(v) - var tagName string - if len(optionalTagName) > 0 { - tagName = optionalTagName[0] - } - return e.marshal(val, getTagType(val.Type()), tagName) + return e.marshal(val, getTagType(val), tagName) } func (e *Encoder) marshal(val reflect.Value, tagType byte, tagName string) error { if err := e.writeHeader(val, tagType, tagName); err != nil { return err } + if val.CanInterface() { + if encoder, ok := val.Interface().(NBTEncoder); ok { + return encoder.Encode(e.w) + } + } return e.writeValue(val, tagType) } func (e *Encoder) writeHeader(val reflect.Value, tagType byte, tagName string) (err error) { if tagType == TagList { - eleType := getTagType(val.Type().Elem()) + var eleType byte + if val.Len() > 0 { + eleType = getTagType(val.Index(0)) + } else { + eleType = getTagTypeByType(val.Type().Elem()) + } err = e.writeListHeader(eleType, tagName, val.Len(), true) } else { err = e.writeTag(tagType, tagName) @@ -96,7 +102,7 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { case TagList: for i := 0; i < val.Len(); i++ { arrVal := val.Index(i) - err := e.writeValue(arrVal, getTagType(arrVal.Type())) + err := e.writeValue(arrVal, getTagType(arrVal)) if err != nil { return err } @@ -119,12 +125,13 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { n := val.NumField() for i := 0; i < n; i++ { f := val.Type().Field(i) + v := val.Field(i) tag := f.Tag.Get("nbt") if (f.PkgPath != "" && !f.Anonymous) || tag == "-" { continue // Private field } - tagProps := parseTag(f, tag) + tagProps := parseTag(f, v, tag) if err := e.marshal(val.Field(i), tagProps.Type, tagProps.Name); err != nil { return err } @@ -139,7 +146,7 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { tagName = r.Key().String() } tagValue := r.Value() - tagType := getTagType(tagValue.Type()) + tagType := getTagType(tagValue) if tagType == TagNone { return errors.New("unsupported value " + tagValue.String()) } @@ -156,7 +163,36 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { return nil } -func getTagType(vk reflect.Type) byte { +func getTagType(v reflect.Value) byte { + if v.CanInterface() { + if encoder, ok := v.Interface().(NBTEncoder); ok { + return encoder.TagType() + } + } + switch v.Kind() { + case reflect.Array, reflect.Slice: + var elemType byte + if v.Len() > 0 { + elemType = getTagType(v.Index(0)) + } else { + elemType = getTagTypeByType(v.Type().Elem()) + } + switch elemType { + case TagByte: // Special types for these values + return TagByteArray + case TagInt: + return TagIntArray + case TagLong: + return TagLongArray + default: + return TagList + } + default: + return getTagTypeByType(v.Type()) + } +} + +func getTagTypeByType(vk reflect.Type) byte { switch vk.Kind() { case reflect.Uint8: return TagByte @@ -174,17 +210,6 @@ func getTagType(vk reflect.Type) byte { return TagString case reflect.Struct, reflect.Interface, reflect.Map: return TagCompound - case reflect.Array, reflect.Slice: - switch vk.Elem().Kind() { - case reflect.Uint8: // Special types for these values - return TagByteArray - case reflect.Int32: - return TagIntArray - case reflect.Int64: - return TagLongArray - default: - return TagList - } default: return TagNone } @@ -195,7 +220,7 @@ type tagProps struct { Type byte } -func parseTag(f reflect.StructField, tagName string) tagProps { +func parseTag(f reflect.StructField, v reflect.Value, tagName string) tagProps { result := tagProps{} result.Name = tagName if result.Name == "" { @@ -203,7 +228,7 @@ func parseTag(f reflect.StructField, tagName string) tagProps { } nbtType := f.Tag.Get("nbt_type") - result.Type = getTagType(f.Type) + result.Type = getTagType(v) if strings.Contains(nbtType, "list") { if IsArrayTag(result.Type) { result.Type = TagList // for expanding the array to a standard list diff --git a/nbt/encode_test.go b/nbt/encode_test.go index ea40e81..8014e5a 100644 --- a/nbt/encode_test.go +++ b/nbt/encode_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func TestMarshal_IntArray(t *testing.T) { +func TestEncoder_Encode_intArray(t *testing.T) { // Test marshal pure Int array v := []int32{0, -10, 3} out := []byte{TagIntArray, 0x00, 0x00, 0, 0, 0, 3, @@ -41,7 +41,7 @@ func TestMarshal_IntArray(t *testing.T) { } } -func TestMarshal_FloatArray(t *testing.T) { +func TestEncoder_Encode_floatArray(t *testing.T) { // Test marshal pure Int array v := []float32{0.3, -100, float32(math.NaN())} out := []byte{TagList, 0x00, 0x00, TagFloat, 0, 0, 0, 3, @@ -56,7 +56,7 @@ func TestMarshal_FloatArray(t *testing.T) { } } -func TestMarshal_String(t *testing.T) { +func TestEncoder_Encode_string(t *testing.T) { v := "Test" out := []byte{TagString, 0x00, 0x00, 0, 4, 'T', 'e', 's', 't'} @@ -68,7 +68,7 @@ func TestMarshal_String(t *testing.T) { } } -func TestMarshal_InterfaceArray(t *testing.T) { +func TestEncoder_Encode_interfaceArray(t *testing.T) { type Struct1 struct { Val int32 } @@ -110,7 +110,7 @@ func TestMarshal_InterfaceArray(t *testing.T) { } } -func TestMarshal_StructArray(t *testing.T) { +func TestEncoder_Encode_structArray(t *testing.T) { type Struct1 struct { Val int32 } @@ -164,9 +164,9 @@ func TestMarshal_StructArray(t *testing.T) { } } -func TestMarshal_bigTest(t *testing.T) { - data, err := Marshal(MakeBigTestStruct(), "Level") - if err != nil { +func TestEncoder_Encode_bigTest(t *testing.T) { + var buf bytes.Buffer + if err := NewEncoder(&buf).Encode(MakeBigTestStruct(), "Level"); err != nil { t.Error(err) } @@ -176,12 +176,12 @@ func TestMarshal_bigTest(t *testing.T) { t.Error(err) } - if !bytes.Equal(data, want) { - t.Errorf("got:\n[% 2x]\nwant:\n[% 2x]", data, want) + if !bytes.Equal(buf.Bytes(), want) { + t.Errorf("got:\n[% 2x]\nwant:\n[% 2x]", buf.Bytes(), want) } } -func TestMarshal_map(t *testing.T) { +func TestEncoder_Encode_map(t *testing.T) { v := map[string][]int32{ "Tnze": {1, 2, 3, 4, 5}, "Xi_Xi_Mi": {0, 0, 4, 7, 2}, @@ -197,7 +197,7 @@ func TestMarshal_map(t *testing.T) { XXM []int32 `nbt:"Xi_Xi_Mi"` } - if err := NewDecoder(bytes.NewReader(b)).Decode(&data); err != nil { + if _, err := NewDecoder(bytes.NewReader(b)).Decode(&data); err != nil { t.Fatal(err) } if !reflect.DeepEqual(data.Tnze, v["Tnze"]) { @@ -207,3 +207,26 @@ func TestMarshal_map(t *testing.T) { t.Fatalf("Marshal map error: got: %#v, want %#v", data.XXM, v["Xi_Xi_Mi"]) } } + +func TestEncoder_Encode_rawMessage(t *testing.T) { + data := []byte{ + TagCompound, 0, 2, 'a', 'b', + TagInt, 0, 3, 'K', 'e', 'y', 0, 0, 0, 12, + TagString, 0, 5, 'V', 'a', 'l', 'u', 'e', 0, 4, 'T', 'n', 'z', 'e', + TagEnd, + } + var container struct { + Key int32 + Value RawMessage + } + container.Key = 12 + container.Value.Type = TagString + container.Value.Data = []byte{0, 4, 'T', 'n', 'z', 'e'} + + var buf bytes.Buffer + if err := NewEncoder(&buf).Encode(container, "ab"); err != nil { + t.Fatalf("Encode error: %v", err) + } else if !bytes.Equal(data, buf.Bytes()) { + t.Fatalf("Encode error: want %v, get: %v", data, buf.Bytes()) + } +} diff --git a/nbt/interface.go b/nbt/interface.go index e05af2c..386fb9a 100644 --- a/nbt/interface.go +++ b/nbt/interface.go @@ -1,9 +1,12 @@ package nbt -type Unmarshaler interface { - Unmarshal(tagType byte, tagName string, r DecoderReader) error +import "io" + +type NBTDecoder interface { + Decode(tagType byte, r DecoderReader) error } -//type Marshaller interface{ -// Marshal() -//} +type NBTEncoder interface { + TagType() byte + Encode(w io.Writer) error +} diff --git a/nbt/raw.go b/nbt/raw.go deleted file mode 100644 index f599038..0000000 --- a/nbt/raw.go +++ /dev/null @@ -1,23 +0,0 @@ -package nbt - -import ( - "bytes" - "io" -) - -type RawMessage []byte - -func (m *RawMessage) Unmarshal(tagType byte, _ string, r DecoderReader) error { - if tagType == TagEnd { - return ErrEND - } - - buf := bytes.NewBuffer((*m)[:0]) - tee := io.TeeReader(r, buf) - err := NewDecoder(tee).rawRead(tagType) - if err != nil { - return err - } - *m = buf.Bytes() - return nil -} diff --git a/nbt/rawmsg.go b/nbt/rawmsg.go new file mode 100644 index 0000000..a80f81b --- /dev/null +++ b/nbt/rawmsg.go @@ -0,0 +1,34 @@ +package nbt + +import ( + "bytes" + "io" +) + +type RawMessage struct { + Type byte + Data []byte +} + +func (m RawMessage) TagType() byte { + return m.Type +} + +func (m RawMessage) Encode(w io.Writer) error { + _, err := w.Write(m.Data) + return err +} + +func (m *RawMessage) Decode(tagType byte, r DecoderReader) error { + if tagType == TagEnd { + return ErrEND + } + buf := bytes.NewBuffer(m.Data[:0]) + tee := io.TeeReader(r, buf) + err := NewDecoder(tee).rawRead(tagType) + if err != nil { + return err + } + m.Data = buf.Bytes() + return nil +} diff --git a/nbt/snbt.go b/nbt/snbt.go new file mode 100644 index 0000000..7f40fe2 --- /dev/null +++ b/nbt/snbt.go @@ -0,0 +1,10 @@ +package nbt + +type StringifiedNBT struct { + Name string + Content string +} + +func (n *StringifiedNBT) Decode(tagType byte, tagName string, r DecoderReader) error { + panic("unimplemented") +} diff --git a/net/packet/types.go b/net/packet/types.go index bc78c7f..31e08ef 100644 --- a/net/packet/types.go +++ b/net/packet/types.go @@ -454,7 +454,7 @@ func (n nbtField) WriteTo(w io.Writer) (int64, error) { func (n nbtField) ReadFrom(r io.Reader) (int64, error) { // LimitReader is used to count reader length lr := &io.LimitedReader{R: r, N: math.MaxInt64} - err := nbt.NewDecoder(lr).Decode(n.V) + _, err := nbt.NewDecoder(lr).Decode(n.V) if err != nil && errors.Is(err, nbt.ErrEND) { err = nil } diff --git a/save/chunk.go b/save/chunk.go index bb5e3d1..46356a6 100644 --- a/save/chunk.go +++ b/save/chunk.go @@ -66,6 +66,6 @@ func (c *Column) Load(data []byte) (err error) { return err } - err = nbt.NewDecoder(r).Decode(c) + _, err = nbt.NewDecoder(r).Decode(c) return } diff --git a/save/level.go b/save/level.go index 63d1aa2..8270ea5 100644 --- a/save/level.go +++ b/save/level.go @@ -65,6 +65,6 @@ type Level struct { } func ReadLevel(r io.Reader) (data Level, err error) { - err = nbt.NewDecoder(r).Decode(&data) + _, err = nbt.NewDecoder(r).Decode(&data) return } diff --git a/save/playerdata.go b/save/playerdata.go index 8982120..36cd16f 100644 --- a/save/playerdata.go +++ b/save/playerdata.go @@ -79,7 +79,7 @@ type Item struct { } func ReadPlayerData(r io.Reader) (data PlayerData, err error) { - err = nbt.NewDecoder(r).Decode(&data) + _, err = nbt.NewDecoder(r).Decode(&data) //parse UUID from two int64s binary.BigEndian.PutUint64(data.UUID[:], uint64(data.UUIDMost)) binary.BigEndian.PutUint64(data.UUID[8:], uint64(data.UUIDLeast)) diff --git a/save/region/mca_test.go b/save/region/mca_test.go index 2ca400c..b7a78fb 100644 --- a/save/region/mca_test.go +++ b/save/region/mca_test.go @@ -40,7 +40,7 @@ func TestReadRegion(t *testing.T) { t.Error(err) } var b interface{} - err = nbt.NewDecoder(r).Decode(&b) + _, err = nbt.NewDecoder(r).Decode(&b) if err != nil { t.Error(err) }