Fix the issue of snbt decoding when Tag*Arrays in TagCompound

This commit is contained in:
Tnze
2023-04-24 01:04:21 +08:00
parent a511ad3d2a
commit cbf5a7c053
5 changed files with 128 additions and 108 deletions

View File

@ -79,7 +79,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
return ErrEND
case TagByte:
value, err := d.readByte()
value, err := d.readInt8()
if err != nil {
return err
}
@ -97,7 +97,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagShort:
value, err := d.readShort()
value, err := d.readInt16()
if err != nil {
return err
}
@ -113,7 +113,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagInt:
value, err := d.readInt()
value, err := d.readInt32()
if err != nil {
return err
}
@ -129,7 +129,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagFloat:
vInt, err := d.readInt()
vInt, err := d.readInt32()
if err != nil {
return err
}
@ -146,7 +146,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagLong:
value, err := d.readLong()
value, err := d.readInt64()
if err != nil {
return err
}
@ -162,7 +162,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagDouble:
vInt, err := d.readLong()
vInt, err := d.readInt64()
if err != nil {
return err
}
@ -199,7 +199,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagByteArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
@ -242,7 +242,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagIntArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
@ -262,7 +262,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
buf = reflect.MakeSlice(vt, int(aryLen), int(aryLen))
}
for i := 0; i < int(aryLen); i++ {
value, err := d.readInt()
value, err := d.readInt32()
if err != nil {
return err
}
@ -273,7 +273,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
}
case TagLongArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
@ -287,7 +287,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
case reflect.Int64:
buf := reflect.MakeSlice(vt, int(aryLen), int(aryLen))
for i := 0; i < int(aryLen); i++ {
value, err := d.readLong()
value, err := d.readInt64()
if err != nil {
return err
}
@ -297,7 +297,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
case reflect.Uint64:
buf := reflect.MakeSlice(vt, int(aryLen), int(aryLen))
for i := 0; i < int(aryLen); i++ {
value, err := d.readLong()
value, err := d.readInt64()
if err != nil {
return err
}
@ -313,7 +313,7 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error {
if err != nil {
return err
}
listLen, err := d.readInt()
listLen, err := d.readInt32()
if err != nil {
return err
}
@ -503,7 +503,7 @@ func (d *Decoder) rawRead(tagType byte) error {
default:
return fmt.Errorf("unknown to read %#02x", tagType)
case TagByte:
_, err := d.readByte()
_, err := d.readInt8()
return err
case TagString:
_, err := d.readString()
@ -518,7 +518,7 @@ func (d *Decoder) rawRead(tagType byte) error {
_, err := io.ReadFull(d.r, buf[:8])
return err
case TagByteArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
@ -527,23 +527,23 @@ func (d *Decoder) rawRead(tagType byte) error {
return err
}
case TagIntArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
for i := 0; i < int(aryLen); i++ {
if _, err := d.readInt(); err != nil {
if _, err := d.readInt32(); err != nil {
return err
}
}
case TagLongArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
for i := 0; i < int(aryLen); i++ {
if _, err := d.readLong(); err != nil {
if _, err := d.readInt64(); err != nil {
return err
}
}
@ -553,7 +553,7 @@ func (d *Decoder) rawRead(tagType byte) error {
if err != nil {
return err
}
listLen, err := d.readInt()
listLen, err := d.readInt32()
if err != nil {
return err
}
@ -597,26 +597,26 @@ func (d *Decoder) readTag() (tagType byte, tagName string, err error) {
return
}
func (d *Decoder) readByte() (int8, error) {
func (d *Decoder) readInt8() (int8, error) {
b, err := d.r.ReadByte()
// TagByte is signed byte (that's what in Java), so we need to convert to int8
return int8(b), err
}
func (d *Decoder) readShort() (int16, error) {
func (d *Decoder) readInt16() (int16, error) {
var data [2]byte
_, err := io.ReadFull(d.r, data[:])
return int16(data[0])<<8 | int16(data[1]), err
}
func (d *Decoder) readInt() (int32, error) {
func (d *Decoder) readInt32() (int32, error) {
var data [4]byte
_, err := io.ReadFull(d.r, data[:])
return int32(data[0])<<24 | int32(data[1])<<16 |
int32(data[2])<<8 | int32(data[3]), err
}
func (d *Decoder) readLong() (int64, error) {
func (d *Decoder) readInt64() (int64, error) {
var data [8]byte
_, err := io.ReadFull(d.r, data[:])
return int64(data[0])<<56 | int64(data[1])<<48 |
@ -626,7 +626,7 @@ func (d *Decoder) readLong() (int64, error) {
}
func (d *Decoder) readString() (string, error) {
length, err := d.readShort()
length, err := d.readInt16()
if err != nil {
return "", err
} else if length < 0 {

View File

@ -89,29 +89,29 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt
writeEscapeStr(sb, str)
return err
case TagShort:
s, err := d.readShort()
s, err := d.readInt16()
sb.WriteString(strconv.FormatInt(int64(s), 10) + "S")
return err
case TagInt:
i, err := d.readInt()
i, err := d.readInt32()
sb.WriteString(strconv.FormatInt(int64(i), 10))
return err
case TagFloat:
i, err := d.readInt()
i, err := d.readInt32()
f := float64(math.Float32frombits(uint32(i)))
sb.WriteString(strconv.FormatFloat(f, 'f', 10, 32) + "F")
return err
case TagLong:
i, err := d.readLong()
i, err := d.readInt64()
sb.WriteString(strconv.FormatInt(i, 10) + "L")
return err
case TagDouble:
i, err := d.readLong()
i, err := d.readInt64()
f := math.Float64frombits(uint64(i))
sb.WriteString(strconv.FormatFloat(f, 'f', 10, 64) + "D")
return err
case TagByteArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
@ -131,14 +131,14 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt
}
sb.WriteString("]")
case TagIntArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
sb.WriteString("[I;")
first := true
for i := 0; i < int(aryLen); i++ {
v, err := d.readInt()
v, err := d.readInt32()
if err != nil {
return err
}
@ -151,14 +151,14 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt
}
sb.WriteString("]")
case TagLongArray:
aryLen, err := d.readInt()
aryLen, err := d.readInt32()
if err != nil {
return err
}
first := true
sb.WriteString("[L;")
for i := 0; i < int(aryLen); i++ {
v, err := d.readLong()
v, err := d.readInt64()
if err != nil {
return err
}
@ -175,7 +175,7 @@ func (m *StringifiedMessage) encode(d *Decoder, sb *strings.Builder, tagType byt
if err != nil {
return err
}
listLen, err := d.readInt()
listLen, err := d.readInt32()
if err != nil {
return err
}

View File

@ -50,12 +50,7 @@ func writeValue(e *Encoder, d *decodeState, writeTag bool, tagName string) error
return writeCompoundPayload(e, d)
case scanBeginList:
if writeTag {
if err := e.writeTag(TagList, tagName); err != nil {
return err
}
}
_, err := writeListOrArray(e, d)
_, err := writeListOrArray(e, d, writeTag, tagName)
return err
}
}
@ -126,7 +121,7 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
return err
}
// Next token must be , or }.
// The next token must be , or }.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
@ -144,9 +139,15 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
return err
}
func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) {
func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) (tagType byte, err error) {
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndValue { // ']', empty TAG_List
if writeTag {
err = e.writeTag(TagList, tagName)
if err != nil {
return tagType, err
}
}
err = e.writeListHeader(TagEnd, 0)
d.scanNext()
return TagList, err
@ -186,77 +187,24 @@ func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) {
default:
return TagList, d.error("unknown Array type")
}
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
d.scanWhile(scanSkipSpace) // ;
if d.opcode == scanEndValue { // ]
// empty array
if err = e.writeInt32(0); err != nil {
return
}
break
}
for {
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode != scanBeginLiteral {
return tagType, d.error("not literal in Array")
}
start := d.readIndex()
if d.scanWhile(scanContinue); d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
literal := d.data[start:d.readIndex()]
var subType byte
var litVal any
subType, litVal, err = parseLiteral(literal)
if writeTag {
err = e.writeTag(tagType, tagName)
if err != nil {
return tagType, err
}
if subType != elemType {
err = d.error("unexpected element type in TAG_Array")
return
}
switch elemType {
case TagByte:
_, err = e2.w.Write([]byte{byte(litVal.(int8))})
case TagInt:
err = e2.writeInt32(litVal.(int32))
case TagLong:
err = e2.writeInt64(litVal.(int64))
}
if err != nil {
return
}
count++
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
if d.opcode == scanEndValue { // ]
break
}
if d.opcode != scanListValue {
panic(phasePanicMsg)
}
d.scanWhile(scanSkipSpace) // ,
}
if err = e.writeInt32(int32(count)); err != nil {
return tagType, err
}
_, err = e.w.Write(buf.Bytes())
err = writeArray(e, e2, d, elemType, &count, &buf)
if err != nil {
return tagType, err
}
break
}
if writeTag {
err = e.writeTag(TagList, tagName)
if err != nil {
return tagType, err
}
}
if d.opcode != scanListValue && d.opcode != scanEndValue { // TAG_List<TAG_String>
panic(phasePanicMsg)
}
@ -314,7 +262,7 @@ func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) {
if d.opcode != scanBeginList {
return TagList, d.error("different TagType in List")
}
elemType, err = writeListOrArray(e2, d)
elemType, err = writeListOrArray(e2, d, false, "")
if err != nil {
return tagType, err
}
@ -387,6 +335,76 @@ func writeListOrArray(e *Encoder, d *decodeState) (tagType byte, err error) {
return
}
func writeArray(e, e2 *Encoder, d *decodeState, elemType byte, count *int, buf *bytes.Buffer) (err error) {
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
d.scanWhile(scanSkipSpace) // ;
if d.opcode == scanEndValue { // ]
// empty array
if err = e.writeInt32(0); err != nil {
return
}
return
}
for {
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode != scanBeginLiteral {
return d.error("not literal in Array")
}
start := d.readIndex()
if d.scanWhile(scanContinue); d.opcode == scanError {
return d.error(d.scan.errContext)
}
literal := d.data[start:d.readIndex()]
var subType byte
var litVal any
subType, litVal, err = parseLiteral(literal)
if err != nil {
return err
}
if subType != elemType {
err = d.error("unexpected element type in TAG_Array")
return
}
switch elemType {
case TagByte:
_, err = e2.w.Write([]byte{byte(litVal.(int8))})
case TagInt:
err = e2.writeInt32(litVal.(int32))
case TagLong:
err = e2.writeInt64(litVal.(int64))
}
if err != nil {
return
}
*count++
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanError {
return d.error(d.scan.errContext)
}
if d.opcode == scanEndValue { // ]
break
}
if d.opcode != scanListValue {
panic(phasePanicMsg)
}
d.scanWhile(scanSkipSpace) // ,
}
if err = e.writeInt32(int32(*count)); err != nil {
return err
}
_, err = e.w.Write(buf.Bytes())
return
}
// readIndex returns the position of the last byte read.
func (d *decodeState) readIndex() int {
return d.off - 1

View File

@ -44,6 +44,8 @@ var testCases = []testCase{
{`[I; 1, 2 ,3]`, TagIntArray, []byte{11, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3}},
{`[L;]`, TagLongArray, []byte{12, 0, 0, 0, 0, 0, 0}},
{`[ L; 1L,2L,3L]`, TagLongArray, []byte{12, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3}},
{`{a:[B;]}`, TagCompound, []byte{10, 0, 0, 7, 0, 1, 'a', 0, 0, 0, 0, 0}},
{`{a:[B;1b,2B,3B]}`, TagCompound, []byte{10, 0, 0, 7, 0, 1, 'a', 0, 0, 0, 3, 1, 2, 3, 0}},
{`{d:[]}`, TagCompound, []byte{10, 0, 0, 9, 0, 1, 'd', 0, 0, 0, 0, 0, 0}},
{`{e:[]}`, TagCompound, []byte{10, 0, 0, 9, 0, 1, 'e', 0, 0, 0, 0, 0, 0}},

View File

@ -8,7 +8,7 @@ const (
scanBeginCompound // begin TAG_Compound (after left-brace )
scanBeginList // begin TAG_List (after left-bracket)
scanListValue // just finished read list value (after comma)
scanListType // just finished read list type (after "B;" or "L;")
scanListType // just finished read list type (after "B;", "I;" or "L;")
scanCompoundTagName // just finished read tag name (before colon)
scanCompoundValue // just finished read value (after comma)
scanSkipSpace // space byte; can skip; known to be last "continue" result