Better io error handling

This commit is contained in:
Tnze
2021-06-22 10:33:07 +08:00
parent 723303ce8d
commit 2ff54efb7d

View File

@ -37,11 +37,15 @@ func writeValue(e *Encoder, d *decodeState, tagName string) error {
} }
literal := d.data[start:d.readIndex()] literal := d.data[start:d.readIndex()]
tagType, litVal := parseLiteral(literal) tagType, litVal := parseLiteral(literal)
e.writeTag(tagType, tagName) if err := e.writeTag(tagType, tagName); err != nil {
return err
}
return writeLiteralPayload(e, litVal) return writeLiteralPayload(e, litVal)
case scanBeginCompound: case scanBeginCompound:
e.writeTag(TagCompound, tagName) if err := e.writeTag(TagCompound, tagName); err != nil {
return err
}
return writeCompoundPayload(e, d) return writeCompoundPayload(e, d)
case scanBeginList: case scanBeginList:
@ -50,26 +54,29 @@ func writeValue(e *Encoder, d *decodeState, tagName string) error {
} }
} }
func writeLiteralPayload(e *Encoder, v interface{}) error { func writeLiteralPayload(e *Encoder, v interface{}) (err error) {
switch v.(type) { switch v.(type) {
case string: case string:
str := v.(string) str := v.(string)
e.writeInt16(int16(len(str))) err = e.writeInt16(int16(len(str)))
e.w.Write([]byte(str)) if err != nil {
case int8: return
e.w.Write([]byte{byte(v.(int8))})
case int16:
e.writeInt16(v.(int16))
case int32:
e.writeInt32(v.(int32))
case int64:
e.writeInt64(v.(int64))
case float32:
e.writeInt32(int32(math.Float32bits(v.(float32))))
case float64:
e.writeInt64(int64(math.Float64bits(v.(float64))))
} }
return nil _, err = e.w.Write([]byte(str))
case int8:
_, err = e.w.Write([]byte{byte(v.(int8))})
case int16:
err = e.writeInt16(v.(int16))
case int32:
err = e.writeInt32(v.(int32))
case int64:
err = e.writeInt64(v.(int64))
case float32:
err = e.writeInt32(int32(math.Float32bits(v.(float32))))
case float64:
err = e.writeInt64(int64(math.Float64bits(v.(float64))))
}
return
} }
func writeCompoundPayload(e *Encoder, d *decodeState) error { func writeCompoundPayload(e *Encoder, d *decodeState) error {
@ -106,7 +113,10 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
if d.opcode != scanCompoundTagName { if d.opcode != scanCompoundTagName {
panic(phasePanicMsg) panic(phasePanicMsg)
} }
writeValue(e, d, tagName)
if err := writeValue(e, d, tagName); err != nil {
return err
}
// Next token must be , or }. // Next token must be , or }.
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
@ -122,16 +132,16 @@ func writeCompoundPayload(e *Encoder, d *decodeState) error {
panic(phasePanicMsg) panic(phasePanicMsg)
} }
} }
e.w.Write([]byte{TagEnd}) _, err := e.w.Write([]byte{TagEnd})
return nil return err
} }
func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) (tagType byte, err error) { func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string) (tagType byte, err error) {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
if d.opcode == scanEndValue { // ']', empty TAG_List if d.opcode == scanEndValue { // ']', empty TAG_List
e.writeListHeader(TagEnd, tagName, 0, writeTag) err = e.writeListHeader(TagEnd, tagName, 0, writeTag)
d.scanNext() d.scanNext()
return TagList, nil return TagList, err
} }
// We don't know the length of the List, // We don't know the length of the List,
@ -169,7 +179,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
return TagList, d.error("unknown Array type") return TagList, d.error("unknown Array type")
} }
if writeTag { if writeTag {
e.writeTag(tagType, tagName) if err = e.writeTag(tagType, tagName); err != nil {
return
}
} }
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
@ -177,7 +189,9 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
d.scanWhile(scanSkipSpace) // ; d.scanWhile(scanSkipSpace) // ;
if d.opcode == scanEndValue { // ] if d.opcode == scanEndValue { // ]
// empty array // empty array
e.writeInt32(0) if err = e.writeInt32(0); err != nil {
return
}
break break
} }
for { for {
@ -193,17 +207,21 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
return tagType, d.error(d.scan.errContext) return tagType, d.error(d.scan.errContext)
} }
literal := d.data[start:d.readIndex()] literal := d.data[start:d.readIndex()]
tagType, litVal := parseLiteral(literal) subType, litVal := parseLiteral(literal)
if tagType != elemType { if subType != elemType {
return tagType, d.error("unexpected element type in TAG_Array") err = d.error("unexpected element type in TAG_Array")
return
} }
switch elemType { switch elemType {
case TagByte: case TagByte:
e2.w.Write([]byte{byte(litVal.(int8))}) _, err = e2.w.Write([]byte{byte(litVal.(int8))})
case TagInt: case TagInt:
e2.writeInt32(litVal.(int32)) err = e2.writeInt32(litVal.(int32))
case TagLong: case TagLong:
e2.writeInt64(litVal.(int64)) err = e2.writeInt64(litVal.(int64))
}
if err != nil {
return
} }
count++ count++
@ -221,8 +239,14 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
} }
d.scanWhile(scanSkipSpace) // , d.scanWhile(scanSkipSpace) // ,
} }
e.writeInt32(int32(count))
e.w.Write(buf.Bytes()) if err = e.writeInt32(int32(count)); err != nil {
return tagType, err
}
_, err = e.w.Write(buf.Bytes())
if err != nil {
return tagType, err
}
break break
} }
if d.opcode != scanListValue { // TAG_List<TAG_String> if d.opcode != scanListValue { // TAG_List<TAG_String>
@ -237,7 +261,10 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
if t != tagType { if t != tagType {
return TagList, d.error("different TagType in List") return TagList, d.error("different TagType in List")
} }
writeLiteralPayload(e2, v) err = writeLiteralPayload(e2, v)
if err != nil {
return tagType, err
}
count++ count++
// read ',' or ']' // read ',' or ']'
@ -260,8 +287,13 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
} }
literal = d.data[start:d.readIndex()] literal = d.data[start:d.readIndex()]
} }
e.writeListHeader(tagType, tagName, count, writeTag)
e.w.Write(buf.Bytes()) if err := e.writeListHeader(tagType, tagName, count, writeTag); err != nil {
return tagType, err
}
if _, err := e.w.Write(buf.Bytes()); err != nil {
return tagType, err
}
case scanBeginList: // TAG_List<TAG_List> case scanBeginList: // TAG_List<TAG_List>
var elemType byte var elemType byte
for { for {
@ -292,8 +324,13 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
// read '[' // read '['
d.scanNext() d.scanNext()
} }
e.writeListHeader(elemType, tagName, count, writeTag)
e.w.Write(buf.Bytes()) if err = e.writeListHeader(elemType, tagName, count, writeTag); err != nil {
return
}
if _, err = e.w.Write(buf.Bytes()); err != nil {
return
}
case scanBeginCompound: // TAG_List<TAG_Compound> case scanBeginCompound: // TAG_List<TAG_Compound>
for { for {
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
@ -302,7 +339,10 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
if d.opcode != scanBeginCompound { if d.opcode != scanBeginCompound {
return TagList, d.error("different TagType in List") return TagList, d.error("different TagType in List")
} }
writeCompoundPayload(e2, d)
if err = writeCompoundPayload(e2, d); err != nil {
return
}
count++ count++
if d.opcode == scanSkipSpace { if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace) d.scanWhile(scanSkipSpace)
@ -323,8 +363,14 @@ func writeListOrArray(e *Encoder, d *decodeState, writeTag bool, tagName string)
// read '{' // read '{'
d.scanNext() d.scanNext()
} }
e.writeListHeader(TagCompound, tagName, count, writeTag)
e.w.Write(buf.Bytes()) if err = e.writeListHeader(TagCompound, tagName, count, writeTag); err != nil {
return
}
if _, err = e.w.Write(buf.Bytes()); err != nil {
return
}
} }
d.scanNext() d.scanNext()
return return