diff --git a/nbt/snbt_decode.go b/nbt/snbt_decode.go index 266f4e0..912df6b 100644 --- a/nbt/snbt_decode.go +++ b/nbt/snbt_decode.go @@ -2,7 +2,6 @@ package nbt import ( "bytes" - "errors" "math" "strconv" "strings" @@ -10,7 +9,7 @@ import ( type decodeState struct { data []byte - off int // next read offset in data + off int // next read Offset in data opcode int // last read result scan scanner } @@ -26,18 +25,25 @@ func (e *Encoder) WriteSNBT(snbt string) error { func writeValue(e *Encoder, d *decodeState, tagName string) error { d.scanWhile(scanSkipSpace) switch d.opcode { + case scanError: + return d.error(d.scan.errContext) default: panic(phasePanicMsg) + case scanBeginLiteral: start := d.readIndex() - d.scanWhile(scanContinue) + if d.scanWhile(scanContinue); d.opcode == scanError { + return d.error(d.scan.errContext) + } literal := d.data[start:d.readIndex()] tagType, litVal := parseLiteral(literal) e.writeTag(tagType, tagName) return writeLiteralPayload(e, litVal) + case scanBeginCompound: e.writeTag(TagCompound, tagName) return writeCompoundPayload(e, d) + case scanBeginList: _, err := writeListOrArray(e, d, true, tagName) return err @@ -73,12 +79,17 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error { if d.opcode == scanEndValue { break } + if d.opcode == scanError { + return d.error(d.scan.errContext) + } if d.opcode != scanBeginLiteral { panic(phasePanicMsg) } // read tag name start := d.readIndex() - d.scanWhile(scanContinue) + if d.scanWhile(scanContinue); d.opcode == scanError { + return d.error(d.scan.errContext) + } var tagName string if tt, v := parseLiteral(d.data[start:d.readIndex()]); tt == TagString { tagName = v.(string) @@ -89,6 +100,9 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error { if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return d.error(d.scan.errContext) + } if d.opcode != scanCompoundTagName { panic(phasePanicMsg) } @@ -98,6 +112,9 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error { if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return d.error(d.scan.errContext) + } if d.opcode == scanEndValue { break } @@ -126,11 +143,16 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) switch d.opcode { case scanBeginLiteral: - d.scanWhile(scanContinue) + if d.scanWhile(scanContinue); d.opcode == scanError { + return TagList, d.error(d.scan.errContext) + } literal := d.data[start:d.readIndex()] if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } if d.opcode == scanListType { // TAG_X_Array var elemType byte switch literal[0] { @@ -144,7 +166,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) tagType = TagLongArray elemType = TagLong default: - return TagList, errors.New("unknown Array type") + return TagList, d.error("unknown Array type") } if writeTag { e.writeTag(tagType, tagName) @@ -163,15 +185,17 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) d.scanWhile(scanSkipSpace) } if d.opcode != scanBeginLiteral { - return tagType, errors.New("not literal in Array") + return tagType, d.error("not literal in Array") } start := d.readIndex() - d.scanWhile(scanContinue) + if d.scanWhile(scanContinue); d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } literal := d.data[start:d.readIndex()] tagType, litVal := parseLiteral(literal) if tagType != elemType { - return tagType, errors.New("unexpected element type in TAG_Array") + return tagType, d.error("unexpected element type in TAG_Array") } switch elemType { case TagByte: @@ -186,6 +210,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } if d.opcode == scanEndValue { // ] break } @@ -208,7 +235,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) tagType = t } if t != tagType { - return TagList, errors.New("different TagType in List") + return TagList, d.error("different TagType in List") } writeLiteralPayload(e2, v) count++ @@ -217,6 +244,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } if d.opcode == scanEndValue { break } @@ -225,7 +255,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) } d.scanWhile(scanSkipSpace) start = d.readIndex() - d.scanWhile(scanContinue) + if d.scanWhile(scanContinue); d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } literal = d.data[start:d.readIndex()] } e.writeListHeader(tagType, tagName, count, writeTag) @@ -237,7 +269,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) d.scanWhile(scanSkipSpace) } if d.opcode != scanBeginList { - return TagList, errors.New("different TagType in List") + return TagList, d.error("different TagType in List") } elemType, err = writeListOrArray(e2, d, false, "") if err != nil { @@ -247,6 +279,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } // ',' or ']' if d.opcode == scanEndValue { break @@ -265,7 +300,7 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) d.scanWhile(scanSkipSpace) } if d.opcode != scanBeginCompound { - return TagList, errors.New("different TagType in List") + return TagList, d.error("different TagType in List") } writeCompoundPayload(e2, d) count++ @@ -276,6 +311,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) if d.opcode == scanSkipSpace { d.scanWhile(scanSkipSpace) } + if d.opcode == scanError { + return tagType, d.error(d.scan.errContext) + } if d.opcode == scanEndValue { break } @@ -414,6 +452,10 @@ func parseLiteral(literal []byte) (byte, interface{}) { panic(phasePanicMsg) } +func (d *decodeState) error(msg string) *SyntaxError { + return &SyntaxError{Message: msg, Offset: d.off} +} + func isIntegerType(c byte) bool { return isFloatType(c) || c == 'B' || c == 'b' || @@ -424,3 +466,10 @@ func isIntegerType(c byte) bool { func isFloatType(c byte) bool { return c == 'F' || c == 'f' || c == 'D' || c == 'd' } + +type SyntaxError struct { + Message string + Offset int +} + +func (e *SyntaxError) Error() string { return e.Message } diff --git a/nbt/snbt_decode_test.go b/nbt/snbt_decode_test.go index 9f74560..a683d18 100644 --- a/nbt/snbt_decode_test.go +++ b/nbt/snbt_decode_test.go @@ -2,6 +2,7 @@ package nbt import ( "bytes" + "strings" "testing" ) @@ -83,3 +84,28 @@ func BenchmarkEncoder_WriteSNBT_bigTest(b *testing.B) { buf.Reset() } } + +func Test_WriteSNBT_nestingList(t *testing.T) { + var buf bytes.Buffer + e := NewEncoder(&buf) + + // Our maximum supported nesting depth is 10000. + // The nesting depth of 10001 is 10000 + err := e.WriteSNBT(strings.Repeat("[", 10001) + strings.Repeat("]", 10001)) + if err != nil { + t.Error(err) + } + + // Following code should return error instant of panic. + buf.Reset() + err = e.WriteSNBT(strings.Repeat("[", 10002) + strings.Repeat("]", 10002)) + if err == nil { + t.Error("Exceeded the maximum depth of support, but no error was reported") + } + // Panic test + buf.Reset() + err = e.WriteSNBT(strings.Repeat("[", 20000) + strings.Repeat("]", 20000)) + if err == nil { + t.Error("Exceeded the maximum depth of support, but no error was reported") + } +} diff --git a/nbt/snbt_scanner.go b/nbt/snbt_scanner.go index d057794..17c5723 100644 --- a/nbt/snbt_scanner.go +++ b/nbt/snbt_scanner.go @@ -1,8 +1,6 @@ package nbt -import ( - "errors" -) +import "strconv" const ( scanContinue = iota // uninteresting byte @@ -35,7 +33,7 @@ const maxNestingDepth = 10000 type scanner struct { step func(s *scanner, c byte) int parseState []int - err error + errContext string endTop bool } @@ -44,7 +42,7 @@ type scanner struct { func (s *scanner) reset() { s.step = stateBeginValue s.parseState = s.parseState[0:0] - s.err = nil + s.errContext = "" s.endTop = false } @@ -74,7 +72,7 @@ func (s *scanner) popParseState() { // eof tells the scanner that the end of input has been reached. // It returns a scan status just as s.step does. func (s *scanner) eof() int { - if s.err != nil { + if s.errContext != "" { return scanError } if s.endTop { @@ -84,8 +82,8 @@ func (s *scanner) eof() int { if s.endTop { return scanEnd } - if s.err == nil { - s.err = errors.New("unexpected end of JSON input") + if s.errContext == "" { + s.errContext = "unexpected end of JSON input" } return scanError } @@ -381,13 +379,13 @@ func stateEndValue(s *scanner, c byte) int { func (s *scanner) error(c byte, context string) int { s.step = stateError - s.err = errors.New(context) + s.errContext = "invalid character " + quoteChar(c) + " " + context return scanError } // stateError is the state after reaching a syntax error, // such as after reading `[1}` or `5.1.2`. -func stateError(s *scanner, c byte) int { +func stateError(*scanner, byte) int { return scanError } @@ -406,3 +404,18 @@ func isAllowedInUnquotedString(c byte) bool { c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' } + +// quoteChar formats c as a quoted character literal +func quoteChar(c byte) string { + // special cases - different from quoted strings + if c == '\'' { + return `'\''` + } + if c == '"' { + return `'"'` + } + + // use quoted string with different quotation marks + s := strconv.Quote(string(c)) + return "'" + s[1:len(s)-1] + "'" +} diff --git a/nbt/snbt_scanner_test.go b/nbt/snbt_scanner_test.go index 2036564..9bf7b37 100644 --- a/nbt/snbt_scanner_test.go +++ b/nbt/snbt_scanner_test.go @@ -35,7 +35,7 @@ func TestSNBT_number(t *testing.T) { } for _, str := range goods { if scan(str) == false { - t.Errorf("scan valid data %q error: %v", str, s.err) + t.Errorf("scan valid data %q error: %v", str, s.errContext) } } } @@ -56,7 +56,7 @@ func TestSNBT_compound(t *testing.T) { for i, c := range []byte(str) { res := s.step(&s, c) if res == scanError { - t.Errorf("scan valid data %q error: %v at [%d]", str[:i], s.err, i) + t.Errorf("scan valid data %q error: %v at [%d]", str[:i], s.errContext, i) break } } @@ -83,7 +83,7 @@ func TestSNBT_list(t *testing.T) { } for _, str := range goods { if scan(str) == false { - t.Errorf("scan valid data %q error: %v", str, s.err) + t.Errorf("scan valid data %q error: %v", str, s.errContext) } } } @@ -95,7 +95,7 @@ func BenchmarkSNBT_bigTest(b *testing.B) { for _, c := range []byte(bigTestSNBT) { res := s.step(&s, c) if res == scanError { - b.Errorf("scan valid data %q error: %v at [%d]", bigTestSNBT[:i], s.err, i) + b.Errorf("scan valid data %q error: %v at [%d]", bigTestSNBT[:i], s.errContext, i) break } }