add support for nonarray lists, fix bugs with structs, make compliant with bigtest.nbt

This commit is contained in:
Mark Asp
2020-09-17 22:38:48 -05:00
parent 610cb0c7d5
commit 95b9ba9360
5 changed files with 246 additions and 152 deletions

View File

@ -5,12 +5,26 @@ import (
"io"
"math"
"reflect"
"strings"
)
var (
ErrMustBeStruct = errors.New("a compound can only be a struct")
)
func Marshal(w io.Writer, v interface{}) error {
return NewEncoder(w).Encode(v)
}
func MarshalCompound(w io.Writer, v interface{}, rootTagName string) error {
enc := NewEncoder(w)
val := reflect.ValueOf(v)
if val.Kind() != reflect.Struct {
return ErrMustBeStruct
}
return enc.marshal(val, TagCompound, rootTagName)
}
type Encoder struct {
w io.Writer
}
@ -21,166 +35,86 @@ func NewEncoder(w io.Writer) *Encoder {
func (e *Encoder) Encode(v interface{}) error {
val := reflect.ValueOf(v)
return e.marshal(val, "")
return e.marshal(val, getTagType(val.Type()), "")
}
func (e *Encoder) marshal(val reflect.Value, tagName string) error {
switch vk := val.Kind(); vk {
default:
return errors.New("unknown type " + vk.String())
func (e *Encoder) marshal(val reflect.Value, tagType byte, tagName string) (err error) {
err = e.writeHeader(val, tagType, tagName)
err = e.writeValue(val, tagType)
return err
}
case reflect.Uint8:
if err := e.writeTag(TagByte, tagName); err != nil {
return err
}
func (e *Encoder) writeHeader(val reflect.Value, tagType byte, tagName string) (err error) {
if tagType == TagList {
eleType := getTagType(val.Type().Elem())
err = e.writeListHeader(eleType, tagName, val.Len())
} else {
err = e.writeTag(tagType, tagName)
}
return err
}
func (e *Encoder) writeValue(val reflect.Value, tagType byte) error {
switch tagType {
default:
return errors.New("unsupported type " + val.Type().Kind().String())
case TagByte:
_, err := e.w.Write([]byte{byte(val.Uint())})
return err
case reflect.Int16, reflect.Uint16:
if err := e.writeTag(TagShort, tagName); err != nil {
return err
}
case TagShort:
return e.writeInt16(int16(val.Int()))
case reflect.Int32, reflect.Uint32:
if err := e.writeTag(TagInt, tagName); err != nil {
return err
}
case TagInt:
return e.writeInt32(int32(val.Int()))
case reflect.Float32:
if err := e.writeTag(TagFloat, tagName); err != nil {
return err
}
case TagFloat:
return e.writeInt32(int32(math.Float32bits(float32(val.Float()))))
case reflect.Int64, reflect.Uint64:
if err := e.writeTag(TagLong, tagName); err != nil {
return err
}
return e.writeInt64(int64(val.Int()))
case reflect.Float64:
if err := e.writeTag(TagDouble, tagName); err != nil {
return err
}
case TagLong:
return e.writeInt64(val.Int())
case TagDouble:
return e.writeInt64(int64(math.Float64bits(val.Float())))
case reflect.Array, reflect.Slice:
case TagByteArray, TagIntArray, TagLongArray:
n := val.Len()
switch val.Type().Elem().Kind() {
case reflect.Uint8: // []byte
if err := e.writeTag(TagByteArray, tagName); err != nil {
return err
}
if err := e.writeInt32(int32(val.Len())); err != nil {
return err
}
if err := e.writeInt32(int32(n)); err != nil {
return err
}
if tagType == TagByteArray {
_, err := e.w.Write(val.Bytes())
return err
} else {
for i := 0; i < n; i++ {
v := val.Index(i).Int()
case reflect.Int32:
if err := e.writeTag(TagIntArray, tagName); err != nil {
return err
}
if err := e.writeInt32(int32(n)); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt32(int32(val.Index(i).Int())); err != nil {
return err
var err error
if tagType == TagIntArray {
err = e.writeInt32(int32(v))
} else if tagType == TagLongArray {
err = e.writeInt64(v)
}
}
case reflect.Int64:
if err := e.writeTag(TagLongArray, tagName); err != nil {
return err
}
if err := e.writeInt32(int32(n)); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt64(val.Index(i).Int()); err != nil {
return err
}
}
case reflect.Int16:
if err := e.writeListHeader(TagShort, tagName, val.Len()); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt16(int16(val.Index(i).Int())); err != nil {
return err
}
}
case reflect.Float32:
if err := e.writeListHeader(TagFloat, tagName, val.Len()); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt32(int32(math.Float32bits(float32(val.Index(i).Float())))); err != nil {
return err
}
}
case reflect.Float64:
if err := e.writeListHeader(TagDouble, tagName, val.Len()); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt64(int64(math.Float64bits(val.Index(i).Float()))); err != nil {
return err
}
}
case reflect.String:
if err := e.writeListHeader(TagString, tagName, n); err != nil {
return err
}
for i := 0; i < n; i++ {
// Write length of this string
s := val.Index(i).String()
if err := e.writeInt16(int16(len(s))); err != nil {
return err
}
// Write string
if _, err := e.w.Write([]byte(s)); err != nil {
return err
}
}
case reflect.Struct, reflect.Interface:
if err := e.writeListHeader(TagCompound, tagName, n); err != nil {
return err
}
for i := 0; i < n; i++ {
elemVal := val.Index(i)
if val.Type().Elem().Kind() == reflect.Interface {
elemVal = reflect.ValueOf(elemVal.Interface())
}
err := e.marshal(elemVal, "")
if err != nil {
return err
}
}
default:
return errors.New("unknown type " + val.Type().String() + " slice")
}
case reflect.String:
if err := e.writeTag(TagString, tagName); err != nil {
return err
case TagList:
for i := 0; i < val.Len(); i++ {
arrVal := val.Index(i)
err := e.writeValue(arrVal, getTagType(arrVal.Type()))
if err != nil {
return err
}
}
case TagString:
if err := e.writeInt16(int16(val.Len())); err != nil {
return err
}
_, err := e.w.Write([]byte(val.String()))
return err
case reflect.Struct:
if err := e.writeTag(TagCompound, ""); err != nil {
return err
case TagCompound:
if val.Kind() == reflect.Interface {
val = reflect.ValueOf(val.Interface())
}
n := val.NumField()
@ -191,12 +125,8 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error {
continue // Private field
}
tagName := f.Name
if tag != "" {
tagName = tag
}
err := e.marshal(val.Field(i), tagName)
tagProps := parseTag(f, tag)
err := e.marshal(val.Field(i), tagProps.Type, tagProps.Name)
if err != nil {
return err
}
@ -207,6 +137,65 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error {
return nil
}
func getTagType(vk reflect.Type) byte {
switch vk.Kind() {
case reflect.Uint8:
return TagByte
case reflect.Int16, reflect.Uint16:
return TagShort
case reflect.Int32, reflect.Uint32:
return TagInt
case reflect.Float32:
return TagFloat
case reflect.Int64, reflect.Uint64:
return TagLong
case reflect.Float64:
return TagDouble
case reflect.String:
return TagString
case reflect.Struct, reflect.Interface:
return TagCompound
case reflect.Array, reflect.Slice:
switch vk.Elem().Kind() {
case reflect.Uint8: // Special types for these values
return TagByteArray
case reflect.Int32:
return TagIntArray
case reflect.Int64:
return TagLongArray
default:
return TagList
}
default:
return TagNone
}
}
type tagProps struct {
Name string
Type byte
}
func parseTag(f reflect.StructField, tagName string) tagProps {
result := tagProps{}
result.Name = tagName
if result.Name == "" {
result.Name = f.Name
}
nbtType := f.Tag.Get("nbt_type")
result.Type = getTagType(f.Type)
if strings.Contains(nbtType, "noarray") {
if IsArrayTag(result.Type) {
result.Type = TagList // for expanding the array to a standard list
} else {
panic("noarray is only supported for array types (byte, int, long)")
}
}
return result
}
func (e *Encoder) writeTag(tagType byte, tagName string) error {
if _, err := e.w.Write([]byte{tagType}); err != nil {
return err
@ -233,11 +222,6 @@ func (e *Encoder) writeListHeader(elementType byte, tagName string, n int) (err
return nil
}
func (e *Encoder) writeNamelessTag(tagType byte, tagName string) error {
_, err := e.w.Write([]byte{tagType})
return err
}
func (e *Encoder) writeInt16(n int16) error {
_, err := e.w.Write([]byte{byte(n >> 8), byte(n)})
return err