diff --git a/nbt/encode.go b/nbt/encode.go index fbc49b2..6f3836f 100644 --- a/nbt/encode.go +++ b/nbt/encode.go @@ -41,7 +41,7 @@ func (e *Encoder) marshal(val reflect.Value, tagType byte, tagName string) error func (e *Encoder) writeHeader(val reflect.Value, tagType byte, tagName string) (err error) { if tagType == TagList { eleType := getTagType(val.Type().Elem()) - err = e.writeListHeader(eleType, tagName, val.Len()) + err = e.writeListHeader(eleType, tagName, val.Len(), true) } else { err = e.writeTag(tagType, tagName) } @@ -224,9 +224,11 @@ func (e *Encoder) writeTag(tagType byte, tagName string) error { return err } -func (e *Encoder) writeListHeader(elementType byte, tagName string, n int) (err error) { - if err = e.writeTag(TagList, tagName); err != nil { - return +func (e *Encoder) writeListHeader(elementType byte, tagName string, n int, writeTag bool) (err error) { + if writeTag { + if err = e.writeTag(TagList, tagName); err != nil { + return + } } if _, err = e.w.Write([]byte{elementType}); err != nil { return diff --git a/nbt/snbt_decode.go b/nbt/snbt_decode.go index 0ffa2c8..21ee862 100644 --- a/nbt/snbt_decode.go +++ b/nbt/snbt_decode.go @@ -39,7 +39,8 @@ func writeValue(e *Encoder, d *decodeState, tagName string) error { e.writeTag(TagCompound, tagName) return writeCompoundPayload(e, d) case scanBeginList: - return writeListOrArray(e, d, tagName) + _, err := writeListOrArray(e, d, true, tagName) + return err } } @@ -102,11 +103,11 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error { return nil } -func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { +func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) (tagType byte, err error) { d.scanWhile(scanSkipSpace) if d.opcode == scanEndValue { // ']', empty TAG_List - e.writeListHeader(TagEnd, tagName, 0) - return nil + e.writeListHeader(TagEnd, tagName, 0, writeTag) + return TagList, nil } // We don't know the length of the List, @@ -124,14 +125,19 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { var elemType byte switch literal[0] { case 'B': - e.writeTag(TagByteArray, tagName) + tagType = TagByteArray elemType = TagByte case 'I': - e.writeTag(TagIntArray, tagName) + tagType = TagIntArray elemType = TagInt case 'L': - e.writeTag(TagLongArray, tagName) + tagType = TagLongArray elemType = TagLong + default: + return TagList, errors.New("unknown Array type") + } + if writeTag { + e.writeTag(tagType, tagName) } for { d.scanNext() @@ -142,7 +148,7 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { break } if d.opcode != scanBeginLiteral { - return errors.New("not literal in Array") + return tagType, errors.New("not literal in Array") } start := d.readIndex() @@ -150,7 +156,7 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { literal := d.data[start:d.readIndex()] tagType, litVal := parseLiteral(literal) if tagType != elemType { - return errors.New("unexpected element type in TAG_Array") + return tagType, errors.New("unexpected element type in TAG_Array") } switch elemType { case TagByte: @@ -176,7 +182,7 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { tagType = t } if t != tagType { - return errors.New("different TagType in List") + return TagList, errors.New("different TagType in List") } writeLiteralPayload(e2, v) count++ @@ -196,10 +202,40 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { d.scanWhile(scanContinue) literal = d.data[start:d.readIndex()] } - e.writeListHeader(tagType, tagName, count) + e.writeListHeader(tagType, tagName, count, writeTag) e.w.Write(buf.Bytes()) case scanBeginList: // TAG_List - e.writeListHeader(TagList, tagName, count) + var elemType byte + for { + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanBeginList { + return TagList, errors.New("different TagType in List") + } + elemType, err = writeListOrArray(e2, d, false, "") + if err != nil { + return tagType, err + } + count++ + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + // read ',' or ']' + d.scanNext() + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndValue { + break + } + if d.opcode != scanListValue { + panic(phasePanicMsg) + } + // read '[' + d.scanNext() + } + e.writeListHeader(elemType, tagName, count, writeTag) e.w.Write(buf.Bytes()) case scanBeginCompound: // TAG_List for { @@ -207,7 +243,7 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { d.scanWhile(scanSkipSpace) } if d.opcode != scanBeginCompound { - return errors.New("different TagType in List") + return TagList, errors.New("different TagType in List") } writeCompoundPayload(e2, d) count++ @@ -228,10 +264,10 @@ func writeListOrArray(e *Encoder, d *decodeState, tagName string) error { // read '{' d.scanNext() } - e.writeListHeader(TagCompound, tagName, count) + e.writeListHeader(TagCompound, tagName, count, writeTag) e.w.Write(buf.Bytes()) } - return nil + return } // readIndex returns the position of the last byte read. diff --git a/nbt/snbt_decode_test.go b/nbt/snbt_decode_test.go index 57d187d..c9a873c 100644 --- a/nbt/snbt_decode_test.go +++ b/nbt/snbt_decode_test.go @@ -33,14 +33,14 @@ func TestEncoder_WriteSNBT(t *testing.T) { {`[a,"b",'c']`, []byte{9, 0, 0, 8, 0, 0, 0, 3, 0, 1, 'a', 0, 1, 'b', 0, 1, 'c'}}, {`[{},{a:1b},{}]`, []byte{9, 0, 0, 10, 0, 0, 0, 3, 0, 1, 0, 1, 'a', 1, 0, 0}}, {`[ { } , { a : 1b } , { } ] `, []byte{9, 0, 0, 10, 0, 0, 0, 3, 0, 1, 0, 1, 'a', 1, 0, 0}}, - {`[[],[]]`, []byte{9, 0, 0, 9, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0}}, + {`[[],[]]`, []byte{9, 0, 0, 9, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, {`[B;]`, []byte{7, 0, 0, 0, 0, 0, 0}}, - {`[B;1,2,3]`, []byte{7, 0, 0, 0, 0, 0, 3, 1, 2, 3}}, + {`[B;1b,2B,3B]`, []byte{7, 0, 0, 0, 0, 0, 3, 1, 2, 3}}, {`[I;]`, []byte{11, 0, 0, 0, 0, 0, 0}}, {`[I;1,2,3]`, []byte{11, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3}}, {`[L;]`, []byte{12, 0, 0, 0, 0, 0, 0}}, - {`[L;1,2,3]`, []byte{12, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3}}, + {`[L;1L,2L,3L]`, []byte{12, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3}}, } for i := range testCases { buf.Reset()