From cbf5a7c053d687805ef959575e3c629e405c0bfd Mon Sep 17 00:00:00 2001 From: Tnze Date: Mon, 24 Apr 2023 01:04:21 +0800 Subject: [PATCH] Fix the issue of snbt decoding when Tag*Arrays in TagCompound --- nbt/decode.go | 50 ++++++------- nbt/snbt.go | 22 +++--- nbt/snbt_decode.go | 160 ++++++++++++++++++++++------------------ nbt/snbt_decode_test.go | 2 + nbt/snbt_scanner.go | 2 +- 5 files changed, 128 insertions(+), 108 deletions(-) diff --git a/nbt/decode.go b/nbt/decode.go index f1bd92d..1533176 100644 --- a/nbt/decode.go +++ b/nbt/decode.go @@ -79,7 +79,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { return ErrEND case TagByte: - value, err := d.readByte() + value, err := d.readInt8() if err != nil { return err } @@ -97,7 +97,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagShort: - value, err := d.readShort() + value, err := d.readInt16() if err != nil { return err } @@ -113,7 +113,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagInt: - value, err := d.readInt() + value, err := d.readInt32() if err != nil { return err } @@ -129,7 +129,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagFloat: - vInt, err := d.readInt() + vInt, err := d.readInt32() if err != nil { return err } @@ -146,7 +146,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagLong: - value, err := d.readLong() + value, err := d.readInt64() if err != nil { return err } @@ -162,7 +162,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagDouble: - vInt, err := d.readLong() + vInt, err := d.readInt64() if err != nil { return err } @@ -199,7 +199,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagByteArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } @@ -242,7 +242,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagIntArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } @@ -262,7 +262,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { buf = reflect.MakeSlice(vt, int(aryLen), int(aryLen)) } for i := 0; i < int(aryLen); i++ { - value, err := d.readInt() + value, err := d.readInt32() if err != nil { return err } @@ -273,7 +273,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagLongArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } @@ -287,7 +287,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { case reflect.Int64: buf := reflect.MakeSlice(vt, int(aryLen), int(aryLen)) for i := 0; i < int(aryLen); i++ { - value, err := d.readLong() + value, err := d.readInt64() if err != nil { return err } @@ -297,7 +297,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { case reflect.Uint64: buf := reflect.MakeSlice(vt, int(aryLen), int(aryLen)) for i := 0; i < int(aryLen); i++ { - value, err := d.readLong() + value, err := d.readInt64() if err != nil { return err } @@ -313,7 +313,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { if err != nil { return err } - listLen, err := d.readInt() + listLen, err := d.readInt32() if err != nil { return err } @@ -503,7 +503,7 @@ func (d *Decoder) rawRead(tagType byte) error { default: return fmt.Errorf("unknown to read %#02x", tagType) case TagByte: - _, err := d.readByte() + _, err := d.readInt8() return err case TagString: _, err := d.readString() @@ -518,7 +518,7 @@ func (d *Decoder) rawRead(tagType byte) error { _, err := io.ReadFull(d.r, buf[:8]) return err case TagByteArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } @@ -527,23 +527,23 @@ func (d *Decoder) rawRead(tagType byte) error { return err } case TagIntArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } for i := 0; i < int(aryLen); i++ { - if _, err := d.readInt(); err != nil { + if _, err := d.readInt32(); err != nil { return err } } case TagLongArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } for i := 0; i < int(aryLen); i++ { - if _, err := d.readLong(); err != nil { + if _, err := d.readInt64(); err != nil { return err } } @@ -553,7 +553,7 @@ func (d *Decoder) rawRead(tagType byte) error { if err != nil { return err } - listLen, err := d.readInt() + listLen, err := d.readInt32() if err != nil { return err } @@ -597,26 +597,26 @@ func (d *Decoder) readTag() (tagType byte, tagName string, err error) { return } -func (d *Decoder) readByte() (int8, error) { +func (d *Decoder) readInt8() (int8, error) { b, err := d.r.ReadByte() // TagByte is signed byte (that's what in Java), so we need to convert to int8 return int8(b), err } -func (d *Decoder) readShort() (int16, error) { +func (d *Decoder) readInt16() (int16, error) { var data [2]byte _, err := io.ReadFull(d.r, data[:]) return int16(data[0])<<8 | int16(data[1]), err } -func (d *Decoder) readInt() (int32, error) { +func (d *Decoder) readInt32() (int32, error) { var data [4]byte _, err := io.ReadFull(d.r, data[:]) return int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3]), err } -func (d *Decoder) readLong() (int64, error) { +func (d *Decoder) readInt64() (int64, error) { var data [8]byte _, err := io.ReadFull(d.r, data[:]) return int64(data[0])<<56 | int64(data[1])<<48 | @@ -626,7 +626,7 @@ func (d *Decoder) readLong() (int64, error) { } func (d *Decoder) readString() (string, error) { - length, err := d.readShort() + length, err := d.readInt16() if err != nil { return "", err } else if length < 0 { diff --git a/nbt/snbt.go b/nbt/snbt.go index b7296f3..6adf2ab 100644 --- a/nbt/snbt.go +++ b/nbt/snbt.go @@ -89,29 +89,29 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt writeEscapeStr(sb, str) return err case TagShort: - s, err := d.readShort() + s, err := d.readInt16() sb.WriteString(strconv.FormatInt(int64(s), 10) + "S") return err case TagInt: - i, err := d.readInt() + i, err := d.readInt32() sb.WriteString(strconv.FormatInt(int64(i), 10)) return err case TagFloat: - i, err := d.readInt() + i, err := d.readInt32() f := float64(math.Float32frombits(uint32(i))) sb.WriteString(strconv.FormatFloat(f, 'f', 10, 32) + "F") return err case TagLong: - i, err := d.readLong() + i, err := d.readInt64() sb.WriteString(strconv.FormatInt(i, 10) + "L") return err case TagDouble: - i, err := d.readLong() + i, err := d.readInt64() f := math.Float64frombits(uint64(i)) sb.WriteString(strconv.FormatFloat(f, 'f', 10, 64) + "D") return err case TagByteArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } @@ -131,14 +131,14 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt } sb.WriteString("]") case TagIntArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } sb.WriteString("[I;") first := true for i := 0; i < int(aryLen); i++ { - v, err := d.readInt() + v, err := d.readInt32() if err != nil { return err } @@ -151,14 +151,14 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt } sb.WriteString("]") case TagLongArray: - aryLen, err := d.readInt() + aryLen, err := d.readInt32() if err != nil { return err } first := true sb.WriteString("[L;") for i := 0; i < int(aryLen); i++ { - v, err := d.readLong() + v, err := d.readInt64() if err != nil { return err } @@ -175,7 +175,7 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt if err != nil { return err } - listLen, err := d.readInt() + listLen, err := d.readInt32() if err != nil { return err } diff --git a/nbt/snbt_decode.go b/nbt/snbt_decode.go index f4706c6..06962d0 100644 --- a/nbt/snbt_decode.go +++ b/nbt/snbt_decode.go @@ -50,12 +50,7 @@ func writeValue(e *Encoder, d *decodeState, writeTag bool, tagName string) error return writeCompoundPayload(e, d) case scanBeginList: - if writeTag { - if err := e.writeTag(TagList, tagName); err != nil { - return err - } - } - _, err := writeListOrArray(e, d) + _, err := writeListOrArray(e, d, writeTag, tagName) return err } } @@ -126,7 +121,7 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error { return err } - // Next token must be , or }. + // The next token must be , or }. if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } @@ -144,9 +139,15 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error { return err } -func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) { +func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) (tagType byte, err error) { d.scanWhile(scanSkipSpace) if d.opcode == scanEndValue { // ']', empty TAG_List + if writeTag { + err = e.writeTag(TagList, tagName) + if err != nil { + return tagType, err + } + } err = e.writeListHeader(TagEnd, 0) d.scanNext() return TagList, err @@ -186,77 +187,24 @@ func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) { default: return TagList, d.error("unknown Array type") } - if d.opcode == scanSkipSpace { - d.scanWhile(scanSkipSpace) - } - d.scanWhile(scanSkipSpace) // ; - if d.opcode == scanEndValue { // ] - // empty array - if err = e.writeInt32(0); err != nil { - return - } - break - } - for { - if d.opcode == scanSkipSpace { - d.scanWhile(scanSkipSpace) - } - if d.opcode != scanBeginLiteral { - return tagType, d.error("not literal in Array") - } - start := d.readIndex() - - if d.scanWhile(scanContinue); d.opcode == scanError { - return tagType, d.error(d.scan.errContext) - } - literal := d.data[start:d.readIndex()] - var subType byte - var litVal any - subType, litVal, err = parseLiteral(literal) + if writeTag { + err = e.writeTag(tagType, tagName) if err != nil { return tagType, err } - if subType != elemType { - err = d.error("unexpected element type in TAG_Array") - return - } - switch elemType { - case TagByte: - _, err = e2.w.Write([]byte{byte(litVal.(int8))}) - case TagInt: - err = e2.writeInt32(litVal.(int32)) - case TagLong: - err = e2.writeInt64(litVal.(int64)) - } - if err != nil { - return - } - count++ - - if d.opcode == scanSkipSpace { - d.scanWhile(scanSkipSpace) - } - if d.opcode == scanError { - return tagType, d.error(d.scan.errContext) - } - if d.opcode == scanEndValue { // ] - break - } - if d.opcode != scanListValue { - panic(phasePanicMsg) - } - d.scanWhile(scanSkipSpace) // , } - - if err = e.writeInt32(int32(count)); err != nil { - return tagType, err - } - _, err = e.w.Write(buf.Bytes()) + err = writeArray(e, e2, d, elemType, &count, &buf) if err != nil { return tagType, err } break } + if writeTag { + err = e.writeTag(TagList, tagName) + if err != nil { + return tagType, err + } + } if d.opcode != scanListValue && d.opcode != scanEndValue { // TAG_List panic(phasePanicMsg) } @@ -314,7 +262,7 @@ func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) { if d.opcode != scanBeginList { return TagList, d.error("different TagType in List") } - elemType, err = writeListOrArray(e2, d) + elemType, err = writeListOrArray(e2, d, false, "") if err != nil { return tagType, err } @@ -387,6 +335,76 @@ func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) { return } +func writeArray(e, e2 *Encoder, d *decodeState, elemType byte, count *int, buf *bytes.Buffer) (err error) { + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + d.scanWhile(scanSkipSpace) // ; + if d.opcode == scanEndValue { // ] + // empty array + if err = e.writeInt32(0); err != nil { + return + } + return + } + for { + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanBeginLiteral { + return d.error("not literal in Array") + } + start := d.readIndex() + + if d.scanWhile(scanContinue); d.opcode == scanError { + return d.error(d.scan.errContext) + } + literal := d.data[start:d.readIndex()] + var subType byte + var litVal any + subType, litVal, err = parseLiteral(literal) + if err != nil { + return err + } + if subType != elemType { + err = d.error("unexpected element type in TAG_Array") + return + } + switch elemType { + case TagByte: + _, err = e2.w.Write([]byte{byte(litVal.(int8))}) + case TagInt: + err = e2.writeInt32(litVal.(int32)) + case TagLong: + err = e2.writeInt64(litVal.(int64)) + } + if err != nil { + return + } + *count++ + + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanError { + return d.error(d.scan.errContext) + } + if d.opcode == scanEndValue { // ] + break + } + if d.opcode != scanListValue { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) // , + } + + if err = e.writeInt32(int32(*count)); err != nil { + return err + } + _, err = e.w.Write(buf.Bytes()) + return +} + // readIndex returns the position of the last byte read. func (d *decodeState) readIndex() int { return d.off - 1 diff --git a/nbt/snbt_decode_test.go b/nbt/snbt_decode_test.go index 91bdbc7..39b3e67 100644 --- a/nbt/snbt_decode_test.go +++ b/nbt/snbt_decode_test.go @@ -44,6 +44,8 @@ var testCases = []testCase{ {`[I; 1, 2 ,3]`, TagIntArray, []byte{11, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3}}, {`[L;]`, TagLongArray, []byte{12, 0, 0, 0, 0, 0, 0}}, {`[ L; 1L,2L,3L]`, TagLongArray, []byte{12, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3}}, + {`{a:[B;]}`, TagCompound, []byte{10, 0, 0, 7, 0, 1, 'a', 0, 0, 0, 0, 0}}, + {`{a:[B;1b,2B,3B]}`, TagCompound, []byte{10, 0, 0, 7, 0, 1, 'a', 0, 0, 0, 3, 1, 2, 3, 0}}, {`{d:[]}`, TagCompound, []byte{10, 0, 0, 9, 0, 1, 'd', 0, 0, 0, 0, 0, 0}}, {`{e:[]}`, TagCompound, []byte{10, 0, 0, 9, 0, 1, 'e', 0, 0, 0, 0, 0, 0}}, diff --git a/nbt/snbt_scanner.go b/nbt/snbt_scanner.go index 0df85ce..da2b078 100644 --- a/nbt/snbt_scanner.go +++ b/nbt/snbt_scanner.go @@ -8,7 +8,7 @@ const ( scanBeginCompound // begin TAG_Compound (after left-brace ) scanBeginList // begin TAG_List (after left-bracket) scanListValue // just finished read list value (after comma) - scanListType // just finished read list type (after "B;" or "L;") + scanListType // just finished read list type (after "B;", "I;" or "L;") scanCompoundTagName // just finished read tag name (before colon) scanCompoundValue // just finished read value (after comma) scanSkipSpace // space byte; can skip; known to be last "continue" result