Better syntax error handling

This commit is contained in:
Tnze
2021-06-22 10:01:41 +08:00
parent d730153750
commit 723303ce8d
4 changed files with 115 additions and 27 deletions

View File

@ -2,7 +2,6 @@ package nbt
import ( import (
"bytes" "bytes"
"errors"
"math" "math"
"strconv" "strconv"
"strings" "strings"
@ -10,7 +9,7 @@ import (
type decodeState struct { type decodeState struct {
data []byte data []byte
off int // next read offset in data off int // next read Offset in data
opcode int // last read result opcode int // last read result
scan scanner scan scanner
} }
@ -26,18 +25,25 @@ func (e *Encoder) WriteSNBT(snbt string) error {
func writeValue(e *Encoder, d *decodeState, tagName string) error { func writeValue(e *Encoder, d *decodeState, tagName string) error {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
switch d.opcode { switch d.opcode {
case scanError:
return d.error(d.scan.errContext)
default: default:
panic(phasePanicMsg) panic(phasePanicMsg)
case scanBeginLiteral: case scanBeginLiteral:
start := d.readIndex() start := d.readIndex()
d.scanWhile(scanContinue) if d.scanWhile(scanContinue); d.opcode == scanError {
return d.error(d.scan.errContext)
}
literal := d.data[start:d.readIndex()] literal := d.data[start:d.readIndex()]
tagType, litVal := parseLiteral(literal) tagType, litVal := parseLiteral(literal)
e.writeTag(tagType, tagName) e.writeTag(tagType, tagName)
return writeLiteralPayload(e, litVal) return writeLiteralPayload(e, litVal)
case scanBeginCompound: case scanBeginCompound:
e.writeTag(TagCompound, tagName) e.writeTag(TagCompound, tagName)
return writeCompoundPayload(e, d) return writeCompoundPayload(e, d)
case scanBeginList: case scanBeginList:
_, err := writeListOrArray(e, d, true, tagName) _, err := writeListOrArray(e, d, true, tagName)
return err return err
@ -73,12 +79,17 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
if d.opcode == scanEndValue { if d.opcode == scanEndValue {
break break
} }
if d.opcode == scanError {
return d.error(d.scan.errContext)
}
if d.opcode != scanBeginLiteral { if d.opcode != scanBeginLiteral {
panic(phasePanicMsg) panic(phasePanicMsg)
} }
// read tag name // read tag name
start := d.readIndex() start := d.readIndex()
d.scanWhile(scanContinue) if d.scanWhile(scanContinue); d.opcode == scanError {
return d.error(d.scan.errContext)
}
var tagName string var tagName string
if tt, v := parseLiteral(d.data[start:d.readIndex()]); tt == TagString { if tt, v := parseLiteral(d.data[start:d.readIndex()]); tt == TagString {
tagName = v.(string) tagName = v.(string)
@ -89,6 +100,9 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return d.error(d.scan.errContext)
}
if d.opcode != scanCompoundTagName { if d.opcode != scanCompoundTagName {
panic(phasePanicMsg) panic(phasePanicMsg)
} }
@ -98,6 +112,9 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return d.error(d.scan.errContext)
}
if d.opcode == scanEndValue { if d.opcode == scanEndValue {
break break
} }
@ -126,11 +143,16 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
switch d.opcode { switch d.opcode {
case scanBeginLiteral: case scanBeginLiteral:
d.scanWhile(scanContinue) if d.scanWhile(scanContinue); d.opcode == scanError {
return TagList, d.error(d.scan.errContext)
}
literal := d.data[start:d.readIndex()] literal := d.data[start:d.readIndex()]
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
if d.opcode == scanListType { // TAG_X_Array if d.opcode == scanListType { // TAG_X_Array
var elemType byte var elemType byte
switch literal[0] { switch literal[0] {
@ -144,7 +166,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
tagType = TagLongArray tagType = TagLongArray
elemType = TagLong elemType = TagLong
default: default:
return TagList, errors.New("unknown Array type") return TagList, d.error("unknown Array type")
} }
if writeTag { if writeTag {
e.writeTag(tagType, tagName) e.writeTag(tagType, tagName)
@ -163,15 +185,17 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode != scanBeginLiteral { if d.opcode != scanBeginLiteral {
return tagType, errors.New("not literal in Array") return tagType, d.error("not literal in Array")
} }
start := d.readIndex() start := d.readIndex()
d.scanWhile(scanContinue) if d.scanWhile(scanContinue); d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
literal := d.data[start:d.readIndex()] literal := d.data[start:d.readIndex()]
tagType, litVal := parseLiteral(literal) tagType, litVal := parseLiteral(literal)
if tagType != elemType { if tagType != elemType {
return tagType, errors.New("unexpected element type in TAG_Array") return tagType, d.error("unexpected element type in TAG_Array")
} }
switch elemType { switch elemType {
case TagByte: case TagByte:
@ -186,6 +210,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
if d.opcode == scanEndValue { // ] if d.opcode == scanEndValue { // ]
break break
} }
@ -208,7 +235,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
tagType = t tagType = t
} }
if t != tagType { if t != tagType {
return TagList, errors.New("different TagType in List") return TagList, d.error("different TagType in List")
} }
writeLiteralPayload(e2, v) writeLiteralPayload(e2, v)
count++ count++
@ -217,6 +244,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
if d.opcode == scanEndValue { if d.opcode == scanEndValue {
break break
} }
@ -225,7 +255,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
} }
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
start = d.readIndex() start = d.readIndex()
d.scanWhile(scanContinue) if d.scanWhile(scanContinue); d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
literal = d.data[start:d.readIndex()] literal = d.data[start:d.readIndex()]
} }
e.writeListHeader(tagType, tagName, count, writeTag) e.writeListHeader(tagType, tagName, count, writeTag)
@ -237,7 +269,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode != scanBeginList { if d.opcode != scanBeginList {
return TagList, errors.New("different TagType in List") return TagList, d.error("different TagType in List")
} }
elemType, err = writeListOrArray(e2, d, false, "") elemType, err = writeListOrArray(e2, d, false, "")
if err != nil { if err != nil {
@ -247,6 +279,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
// ',' or ']' // ',' or ']'
if d.opcode == scanEndValue { if d.opcode == scanEndValue {
break break
@ -265,7 +300,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode != scanBeginCompound { if d.opcode != scanBeginCompound {
return TagList, errors.New("different TagType in List") return TagList, d.error("different TagType in List")
} }
writeCompoundPayload(e2, d) writeCompoundPayload(e2, d)
count++ count++
@ -276,6 +311,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
} }
if d.opcode == scanError {
return tagType, d.error(d.scan.errContext)
}
if d.opcode == scanEndValue { if d.opcode == scanEndValue {
break break
} }
@ -414,6 +452,10 @@ func parseLiteral(literal []byte) (byte, interface{}) {
panic(phasePanicMsg) panic(phasePanicMsg)
} }
// error builds a *SyntaxError carrying msg and the decoder's current
// read position (d.off, the next read offset in data).
func (d *decodeState) error(msg string) *SyntaxError {
	syntaxErr := new(SyntaxError)
	syntaxErr.Message = msg
	syntaxErr.Offset = d.off
	return syntaxErr
}
func isIntegerType(c byte) bool { func isIntegerType(c byte) bool {
return isFloatType(c) || return isFloatType(c) ||
c == 'B' || c == 'b' || c == 'B' || c == 'b' ||
@ -424,3 +466,10 @@ func isIntegerType(c byte) bool {
func isFloatType(c byte) bool { func isFloatType(c byte) bool {
return c == 'F' || c == 'f' || c == 'D' || c == 'd' return c == 'F' || c == 'f' || c == 'D' || c == 'd'
} }
// SyntaxError describes a syntax problem found while processing SNBT input.
type SyntaxError struct {
	// Message is the human-readable description of the problem.
	Message string
	// Offset is the decoder's read offset recorded when the error was created.
	Offset int
}

// Error implements the error interface. It returns Message unchanged;
// the offset is not included in the text.
func (e *SyntaxError) Error() string {
	return e.Message
}

View File

@ -2,6 +2,7 @@ package nbt
import ( import (
"bytes" "bytes"
"strings"
"testing" "testing"
) )
@ -83,3 +84,28 @@ func BenchmarkEncoder_WriteSNBT_bigTest(b *testing.B) {
buf.Reset() buf.Reset()
} }
} }
// Test_WriteSNBT_nestingList feeds WriteSNBT deeply nested empty lists and
// checks the behaviour around the maximum supported nesting depth (10000).
// Input beyond the limit must be reported as an error instead of panicking.
func Test_WriteSNBT_nestingList(t *testing.T) {
	var buf bytes.Buffer
	e := NewEncoder(&buf)

	cases := []struct {
		brackets int  // number of '[' followed by the same number of ']'
		wantErr  bool // whether WriteSNBT is expected to fail
	}{
		// 10001 brackets is treated as nesting depth 10000, the deepest accepted input.
		{brackets: 10001, wantErr: false},
		// One level past the limit: must return an error.
		{brackets: 10002, wantErr: true},
		// Far past the limit: must still return an error, not panic.
		{brackets: 20000, wantErr: true},
	}

	for _, tc := range cases {
		buf.Reset()
		input := strings.Repeat("[", tc.brackets) + strings.Repeat("]", tc.brackets)
		err := e.WriteSNBT(input)
		switch {
		case !tc.wantErr && err != nil:
			t.Error(err)
		case tc.wantErr && err == nil:
			t.Error("Exceeded the maximum depth of support, but no error was reported")
		}
	}
}

View File

@ -1,8 +1,6 @@
package nbt package nbt
import ( import "strconv"
"errors"
)
const ( const (
scanContinue = iota // uninteresting byte scanContinue = iota // uninteresting byte
@ -35,7 +33,7 @@ const maxNestingDepth = 10000
type scanner struct { type scanner struct {
step func(s *scanner, c byte) int step func(s *scanner, c byte) int
parseState []int parseState []int
err error errContext string
endTop bool endTop bool
} }
@ -44,7 +42,7 @@ type scanner struct {
func (s *scanner) reset() { func (s *scanner) reset() {
s.step = stateBeginValue s.step = stateBeginValue
s.parseState = s.parseState[0:0] s.parseState = s.parseState[0:0]
s.err = nil s.errContext = ""
s.endTop = false s.endTop = false
} }
@ -74,7 +72,7 @@ func (s *scanner) popParseState() {
// eof tells the scanner that the end of input has been reached. // eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does. // It returns a scan status just as s.step does.
func (s *scanner) eof() int { func (s *scanner) eof() int {
if s.err != nil { if s.errContext != "" {
return scanError return scanError
} }
if s.endTop { if s.endTop {
@ -84,8 +82,8 @@ func (s *scanner) eof() int {
if s.endTop { if s.endTop {
return scanEnd return scanEnd
} }
if s.err == nil { if s.errContext == "" {
s.err = errors.New("unexpected end of JSON input") s.errContext = "unexpected end of JSON input"
} }
return scanError return scanError
} }
@ -381,13 +379,13 @@ func stateEndValue(s *scanner, c byte) int {
func (s *scanner) error(c byte, context string) int { func (s *scanner) error(c byte, context string) int {
s.step = stateError s.step = stateError
s.err = errors.New(context) s.errContext = "invalid character " + quoteChar(c) + " " + context
return scanError return scanError
} }
// stateError is the state after reaching a syntax error, // stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`. // such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int { func stateError(*scanner, byte) int {
return scanError return scanError
} }
@ -406,3 +404,18 @@ func isAllowedInUnquotedString(c byte) bool {
c >= 'A' && c <= 'Z' || c >= 'A' && c <= 'Z' ||
c >= 'a' && c <= 'z' c >= 'a' && c <= 'z'
} }
// quoteChar renders c as a single-quoted character literal for use in
// error messages.
func quoteChar(c byte) string {
	switch c {
	case '\'':
		// A single quote has to be escaped inside single quotes.
		return `'\''`
	case '"':
		// A double quote needs no escaping inside single quotes.
		return `'"'`
	}

	// Let strconv escape the byte as a double-quoted string, then swap the
	// surrounding double quotes for single quotes.
	quoted := strconv.Quote(string(c))
	inner := quoted[1 : len(quoted)-1]
	return "'" + inner + "'"
}

View File

@ -35,7 +35,7 @@ func TestSNBT_number(t *testing.T) {
} }
for _, str := range goods { for _, str := range goods {
if scan(str) == false { if scan(str) == false {
t.Errorf("scan valid data %q error: %v", str, s.err) t.Errorf("scan valid data %q error: %v", str, s.errContext)
} }
} }
} }
@ -56,7 +56,7 @@ func TestSNBT_compound(t *testing.T) {
for i, c := range []byte(str) { for i, c := range []byte(str) {
res := s.step(&s, c) res := s.step(&s, c)
if res == scanError { if res == scanError {
t.Errorf("scan valid data %q error: %v at [%d]", str[:i], s.err, i) t.Errorf("scan valid data %q error: %v at [%d]", str[:i], s.errContext, i)
break break
} }
} }
@ -83,7 +83,7 @@ func TestSNBT_list(t *testing.T) {
} }
for _, str := range goods { for _, str := range goods {
if scan(str) == false { if scan(str) == false {
t.Errorf("scan valid data %q error: %v", str, s.err) t.Errorf("scan valid data %q error: %v", str, s.errContext)
} }
} }
} }
@ -95,7 +95,7 @@ func BenchmarkSNBT_bigTest(b *testing.B) {
for _, c := range []byte(bigTestSNBT) { for _, c := range []byte(bigTestSNBT) {
res := s.step(&s, c) res := s.step(&s, c)
if res == scanError { if res == scanError {
b.Errorf("scan valid data %q error: %v at [%d]", bigTestSNBT[:i], s.err, i) b.Errorf("scan valid data %q error: %v at [%d]", bigTestSNBT[:i], s.errContext, i)
break break
} }
} }