Merge pull request #81 from masp/master

NBT: Add support for nonarray lists, fix bugs with structs, make compliant with bigtest.nbt
This commit is contained in:
Tnze
2020-09-22 01:07:32 +08:00
committed by GitHub
4 changed files with 299 additions and 258 deletions

View File

@ -5,12 +5,26 @@ import (
"io" "io"
"math" "math"
"reflect" "reflect"
"strings"
)
var (
ErrMustBeStruct = errors.New("a compound can only be a struct")
) )
func Marshal(w io.Writer, v interface{}) error { func Marshal(w io.Writer, v interface{}) error {
return NewEncoder(w).Encode(v) return NewEncoder(w).Encode(v)
} }
func MarshalCompound(w io.Writer, v interface{}, rootTagName string) error {
enc := NewEncoder(w)
val := reflect.ValueOf(v)
if val.Kind() != reflect.Struct {
return ErrMustBeStruct
}
return enc.marshal(val, TagCompound, rootTagName)
}
type Encoder struct { type Encoder struct {
w io.Writer w io.Writer
} }
@ -21,166 +35,87 @@ func NewEncoder(w io.Writer) *Encoder {
func (e *Encoder) Encode(v interface{}) error { func (e *Encoder) Encode(v interface{}) error {
val := reflect.ValueOf(v) val := reflect.ValueOf(v)
return e.marshal(val, "") return e.marshal(val, getTagType(val.Type()), "")
} }
func (e *Encoder) marshal(val reflect.Value, tagName string) error { func (e *Encoder) marshal(val reflect.Value, tagType byte, tagName string) error {
switch vk := val.Kind(); vk { if err := e.writeHeader(val, tagType, tagName); err != nil {
default:
return errors.New("unknown type " + vk.String())
case reflect.Uint8:
if err := e.writeTag(TagByte, tagName); err != nil {
return err return err
} }
return e.writeValue(val, tagType)
}
func (e *Encoder) writeHeader(val reflect.Value, tagType byte, tagName string) (err error) {
if tagType == TagList {
eleType := getTagType(val.Type().Elem())
err = e.writeListHeader(eleType, tagName, val.Len())
} else {
err = e.writeTag(tagType, tagName)
}
return err
}
func (e *Encoder) writeValue(val reflect.Value, tagType byte) error {
switch tagType {
default:
return errors.New("unsupported type " + val.Type().Kind().String())
case TagByte:
_, err := e.w.Write([]byte{byte(val.Uint())}) _, err := e.w.Write([]byte{byte(val.Uint())})
return err return err
case TagShort:
case reflect.Int16, reflect.Uint16:
if err := e.writeTag(TagShort, tagName); err != nil {
return err
}
return e.writeInt16(int16(val.Int())) return e.writeInt16(int16(val.Int()))
case TagInt:
case reflect.Int32, reflect.Uint32:
if err := e.writeTag(TagInt, tagName); err != nil {
return err
}
return e.writeInt32(int32(val.Int())) return e.writeInt32(int32(val.Int()))
case TagFloat:
case reflect.Float32:
if err := e.writeTag(TagFloat, tagName); err != nil {
return err
}
return e.writeInt32(int32(math.Float32bits(float32(val.Float())))) return e.writeInt32(int32(math.Float32bits(float32(val.Float()))))
case TagLong:
case reflect.Int64, reflect.Uint64: return e.writeInt64(val.Int())
if err := e.writeTag(TagLong, tagName); err != nil { case TagDouble:
return err
}
return e.writeInt64(int64(val.Int()))
case reflect.Float64:
if err := e.writeTag(TagDouble, tagName); err != nil {
return err
}
return e.writeInt64(int64(math.Float64bits(val.Float()))) return e.writeInt64(int64(math.Float64bits(val.Float())))
case TagByteArray, TagIntArray, TagLongArray:
case reflect.Array, reflect.Slice:
n := val.Len() n := val.Len()
switch val.Type().Elem().Kind() { if err := e.writeInt32(int32(n)); err != nil {
case reflect.Uint8: // []byte
if err := e.writeTag(TagByteArray, tagName); err != nil {
return err
}
if err := e.writeInt32(int32(val.Len())); err != nil {
return err return err
} }
if tagType == TagByteArray {
_, err := e.w.Write(val.Bytes()) _, err := e.w.Write(val.Bytes())
return err return err
} else {
for i := 0; i < n; i++ {
v := val.Index(i).Int()
case reflect.Int32: var err error
if err := e.writeTag(TagIntArray, tagName); err != nil { if tagType == TagIntArray {
return err err = e.writeInt32(int32(v))
} else if tagType == TagLongArray {
err = e.writeInt64(v)
} }
if err := e.writeInt32(int32(n)); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt32(int32(val.Index(i).Int())); err != nil {
return err
}
}
case reflect.Int64:
if err := e.writeTag(TagLongArray, tagName); err != nil {
return err
}
if err := e.writeInt32(int32(n)); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt64(val.Index(i).Int()); err != nil {
return err
}
}
case reflect.Int16:
if err := e.writeListHeader(TagShort, tagName, val.Len()); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt16(int16(val.Index(i).Int())); err != nil {
return err
}
}
case reflect.Float32:
if err := e.writeListHeader(TagFloat, tagName, val.Len()); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt32(int32(math.Float32bits(float32(val.Index(i).Float())))); err != nil {
return err
}
}
case reflect.Float64:
if err := e.writeListHeader(TagDouble, tagName, val.Len()); err != nil {
return err
}
for i := 0; i < n; i++ {
if err := e.writeInt64(int64(math.Float64bits(val.Index(i).Float()))); err != nil {
return err
}
}
case reflect.String:
if err := e.writeListHeader(TagString, tagName, n); err != nil {
return err
}
for i := 0; i < n; i++ {
// Write length of this string
s := val.Index(i).String()
if err := e.writeInt16(int16(len(s))); err != nil {
return err
}
// Write string
if _, err := e.w.Write([]byte(s)); err != nil {
return err
}
}
case reflect.Struct, reflect.Interface:
if err := e.writeListHeader(TagCompound, tagName, n); err != nil {
return err
}
for i := 0; i < n; i++ {
elemVal := val.Index(i)
if val.Type().Elem().Kind() == reflect.Interface {
elemVal = reflect.ValueOf(elemVal.Interface())
}
err := e.marshal(elemVal, "")
if err != nil { if err != nil {
return err return err
} }
} }
default:
return errors.New("unknown type " + val.Type().String() + " slice")
} }
case reflect.String: case TagList:
if err := e.writeTag(TagString, tagName); err != nil { for i := 0; i < val.Len(); i++ {
arrVal := val.Index(i)
err := e.writeValue(arrVal, getTagType(arrVal.Type()))
if err != nil {
return err return err
} }
}
case TagString:
if err := e.writeInt16(int16(val.Len())); err != nil { if err := e.writeInt16(int16(val.Len())); err != nil {
return err return err
} }
_, err := e.w.Write([]byte(val.String())) _, err := e.w.Write([]byte(val.String()))
return err return err
case reflect.Struct: case TagCompound:
if err := e.writeTag(TagCompound, ""); err != nil { if val.Kind() == reflect.Interface {
return err val = reflect.ValueOf(val.Interface())
} }
n := val.NumField() n := val.NumField()
@ -191,12 +126,8 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error {
continue // Private field continue // Private field
} }
tagName := f.Name tagProps := parseTag(f, tag)
if tag != "" { err := e.marshal(val.Field(i), tagProps.Type, tagProps.Name)
tagName = tag
}
err := e.marshal(val.Field(i), tagName)
if err != nil { if err != nil {
return err return err
} }
@ -207,6 +138,65 @@ func (e *Encoder) marshal(val reflect.Value, tagName string) error {
return nil return nil
} }
func getTagType(vk reflect.Type) byte {
switch vk.Kind() {
case reflect.Uint8:
return TagByte
case reflect.Int16, reflect.Uint16:
return TagShort
case reflect.Int32, reflect.Uint32:
return TagInt
case reflect.Float32:
return TagFloat
case reflect.Int64, reflect.Uint64:
return TagLong
case reflect.Float64:
return TagDouble
case reflect.String:
return TagString
case reflect.Struct, reflect.Interface:
return TagCompound
case reflect.Array, reflect.Slice:
switch vk.Elem().Kind() {
case reflect.Uint8: // Special types for these values
return TagByteArray
case reflect.Int32:
return TagIntArray
case reflect.Int64:
return TagLongArray
default:
return TagList
}
default:
return TagNone
}
}
type tagProps struct {
Name string
Type byte
}
func parseTag(f reflect.StructField, tagName string) tagProps {
result := tagProps{}
result.Name = tagName
if result.Name == "" {
result.Name = f.Name
}
nbtType := f.Tag.Get("nbt_type")
result.Type = getTagType(f.Type)
if strings.Contains(nbtType, "noarray") {
if IsArrayTag(result.Type) {
result.Type = TagList // for expanding the array to a standard list
} else {
panic("noarray is only supported for array types (byte, int, long)")
}
}
return result
}
func (e *Encoder) writeTag(tagType byte, tagName string) error { func (e *Encoder) writeTag(tagType byte, tagName string) error {
if _, err := e.w.Write([]byte{tagType}); err != nil { if _, err := e.w.Write([]byte{tagType}); err != nil {
return err return err
@ -233,11 +223,6 @@ func (e *Encoder) writeListHeader(elementType byte, tagName string, n int) (err
return nil return nil
} }
func (e *Encoder) writeNamelessTag(tagType byte, tagName string) error {
_, err := e.w.Write([]byte{tagType})
return err
}
func (e *Encoder) writeInt16(n int16) error { func (e *Encoder) writeInt16(n int16) error {
_, err := e.w.Write([]byte{byte(n >> 8), byte(n)}) _, err := e.w.Write([]byte{byte(n >> 8), byte(n)})
return err return err

View File

@ -56,6 +56,19 @@ func TestMarshal_FloatArray(t *testing.T) {
} }
} }
func TestMarshal_String(t *testing.T) {
v := "Test"
out := []byte{TagString, 0x00, 0x00, 0, 4,
'T', 'e', 's', 't'}
var buf bytes.Buffer
if err := Marshal(&buf, v); err != nil {
t.Error(err)
} else if !bytes.Equal(buf.Bytes(), out) {
t.Errorf("output binary not right: got % 02x, want % 02x ", buf.Bytes(), out)
}
}
func TestMarshal_InterfaceArray(t *testing.T) { func TestMarshal_InterfaceArray(t *testing.T) {
type Struct1 struct { type Struct1 struct {
Val int32 Val int32
@ -76,16 +89,15 @@ func TestMarshal_InterfaceArray(t *testing.T) {
want: []byte{ want: []byte{
TagList, 0x00, 0x00 /*no name*/, TagCompound, 0, 0, 0, 2, TagList, 0x00, 0x00 /*no name*/, TagCompound, 0, 0, 0, 2,
// 1st element // 1st element
TagCompound, 0x00, 0x00, /*no name*/
TagInt, 0x00, 0x03, 'V', 'a', 'l', 0x00, 0x00, 0x00, 0x03, // 3 TagInt, 0x00, 0x03, 'V', 'a', 'l', 0x00, 0x00, 0x00, 0x03, // 3
TagEnd, TagEnd,
// 2nd element // 2nd element
TagCompound, 0x00, 0x00, /*no name*/
TagFloat, 0x00, 0x03, 'V', 'a', 'l', 0x3e, 0x99, 0x99, 0x9a, // 0.3 TagFloat, 0x00, 0x03, 'V', 'a', 'l', 0x3e, 0x99, 0x99, 0x9a, // 0.3
TagEnd, TagEnd,
}, },
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
@ -105,24 +117,39 @@ func TestMarshal_StructArray(t *testing.T) {
Val int32 Val int32
} }
type Struct2 struct {
T int32
Ele Struct1
}
type StructCont struct {
V []Struct2
}
tests := []struct { tests := []struct {
name string name string
args []Struct1 args StructCont
want []byte want []byte
}{ }{
{ {
name: "One element struct array", name: "One element struct array",
args: []Struct1{{3}, {-10}}, args: StructCont{[]Struct2{{3, Struct1{3}}, {-10, Struct1{-10}}}},
want: []byte{ want: []byte{
TagList, 0x00, 0x00 /*no name*/, TagCompound, 0, 0, 0, 2, TagCompound, 0x00, 0x00,
// 1st element TagList, 0x00, 0x01, 'V', TagCompound, 0, 0, 0, 2,
TagCompound, 0x00, 0x00, /*no name*/ // Struct2
TagInt, 0x00, 0x01, 'T', 0x00, 0x00, 0x00, 0x03,
TagCompound, 0x00, 0x03, 'E', 'l', 'e',
TagInt, 0x00, 0x03, 'V', 'a', 'l', 0x00, 0x00, 0x00, 0x03, // 3 TagInt, 0x00, 0x03, 'V', 'a', 'l', 0x00, 0x00, 0x00, 0x03, // 3
TagEnd, TagEnd,
TagEnd,
// 2nd element // 2nd element
TagCompound, 0x00, 0x00, /*no name*/ TagInt, 0x00, 0x01, 'T', 0xff, 0xff, 0xff, 0xf6,
TagCompound, 0x00, 0x03, 'E', 'l', 'e',
TagInt, 0x00, 0x03, 'V', 'a', 'l', 0xff, 0xff, 0xff, 0xf6, // -10 TagInt, 0x00, 0x03, 'V', 'a', 'l', 0xff, 0xff, 0xff, 0xf6, // -10
TagEnd, TagEnd,
TagEnd,
TagEnd,
}, },
}, },
} }

View File

@ -22,8 +22,13 @@ const (
TagCompound TagCompound
TagIntArray TagIntArray
TagLongArray TagLongArray
TagNone = 0xFF
) )
func IsArrayTag(ty byte) bool {
return ty == TagByteArray || ty == TagIntArray || ty == TagLongArray
}
type DecoderReader = interface { type DecoderReader = interface {
io.ByteScanner io.ByteScanner
io.Reader io.Reader

View File

@ -3,6 +3,7 @@ package nbt
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io/ioutil"
"reflect" "reflect"
"testing" "testing"
) )
@ -60,9 +61,8 @@ func TestUnmarshal_simple(t *testing.T) {
} }
} }
func TestUnmarshal_bitTest(t *testing.T) { // Generated by vscode-hexdump
// Generated by vscode-hexdump var bigTestData = [...]byte{
data := []byte{
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0xed, 0x54, 0xcf, 0x4f, 0x1a, 0x41, 0x00, 0x00, 0xed, 0x54, 0xcf, 0x4f, 0x1a, 0x41,
0x14, 0x7e, 0xc2, 0x02, 0xcb, 0x96, 0x82, 0xb1, 0x14, 0x7e, 0xc2, 0x02, 0xcb, 0x96, 0x82, 0xb1,
@ -127,44 +127,36 @@ func TestUnmarshal_bitTest(t *testing.T) {
0xc2, 0xec, 0xfe, 0xfc, 0x7a, 0xfb, 0x7d, 0x78, 0xc2, 0xec, 0xfe, 0xfc, 0x7a, 0xfb, 0x7d, 0x78,
0xd3, 0x84, 0xdf, 0xd4, 0xf2, 0xa4, 0xfb, 0x08, 0xd3, 0x84, 0xdf, 0xd4, 0xf2, 0xa4, 0xfb, 0x08,
0x06, 0x00, 0x00, 0x06, 0x00, 0x00,
} }
type BitTestStruct struct { type BigTestStruct struct {
LongTest int64 `nbt:"longTest"`
ShortTest int16 `nbt:"shortTest"`
StringTest string `nbt:"stringTest"`
FloatTest float32 `nbt:"floatTest"`
IntTest int32 `nbt:"intTest"`
NCT struct { NCT struct {
Egg struct {
Name string `nbt:"name"`
Value float32 `nbt:"value"`
} `nbt:"egg"`
Ham struct { Ham struct {
Name string `nbt:"name"` Name string `nbt:"name"`
Value float32 `nbt:"value"` Value float32 `nbt:"value"`
} `nbt:"ham"` } `nbt:"ham"`
} `nbt:"nested compound test"` Egg struct {
IntTest int `nbt:"intTest"`
ByteTest byte `nbt:"byteTest"`
StringTest string `nbt:"stringTest"`
ListTest []int64 `nbt:"listTest (long)"`
DoubleTest float64 `nbt:"doubleTest"`
LongTest int64 `nbt:"longTest"`
ListTest2 [2]struct {
CreatedOn int64 `nbt:"created-on"`
Name string `nbt:"name"` Name string `nbt:"name"`
Value float32 `nbt:"value"`
} `nbt:"egg"`
} `nbt:"nested compound test"`
ListTest []int64 `nbt:"listTest (long)" nbt_type:"noarray"`
ListTest2 [2]struct {
Name string `nbt:"name"`
CreatedOn int64 `nbt:"created-on"`
} `nbt:"listTest (compound)"` } `nbt:"listTest (compound)"`
ByteTest byte `nbt:"byteTest"`
ByteArrayTest []byte `nbt:"byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, starting with n=0 (0, 62, 34, 16, 8, ...))"` ByteArrayTest []byte `nbt:"byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, starting with n=0 (0, 62, 34, 16, 8, ...))"`
ShortTest int16 `nbt:"shortTest"` DoubleTest float64 `nbt:"doubleTest"`
} }
//test parse func MakeBigTestStruct() BigTestStruct {
var value BitTestStruct var want BigTestStruct
r, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
if err := NewDecoder(r).Decode(&value); err != nil {
t.Fatal(err)
}
var want BitTestStruct
want.NCT.Egg.Name = "Eggbert" want.NCT.Egg.Name = "Eggbert"
want.NCT.Egg.Value = 0.5 want.NCT.Egg.Value = 0.5
want.NCT.Ham.Name = "Hampus" want.NCT.Ham.Name = "Hampus"
@ -174,6 +166,7 @@ func TestUnmarshal_bitTest(t *testing.T) {
want.StringTest = "HELLO WORLD THIS IS A TEST STRING \xc3\x85\xc3\x84\xc3\x96!" want.StringTest = "HELLO WORLD THIS IS A TEST STRING \xc3\x85\xc3\x84\xc3\x96!"
want.ListTest = []int64{11, 12, 13, 14, 15} want.ListTest = []int64{11, 12, 13, 14, 15}
want.DoubleTest = 0.49312871321823148 want.DoubleTest = 0.49312871321823148
want.FloatTest = 0.49823147058486938
want.LongTest = 9223372036854775807 want.LongTest = 9223372036854775807
want.ListTest2[0].CreatedOn = 1264099775885 want.ListTest2[0].CreatedOn = 1264099775885
want.ListTest2[0].Name = "Compound tag #0" want.ListTest2[0].Name = "Compound tag #0"
@ -184,14 +177,28 @@ func TestUnmarshal_bitTest(t *testing.T) {
want.ByteArrayTest[n] = byte((n*n*255 + n*7) % 100) want.ByteArrayTest[n] = byte((n*n*255 + n*7) % 100)
} }
want.ShortTest = 32767 want.ShortTest = 32767
return want
}
func TestUnmarshal_bigTest(t *testing.T) {
//test parse
var value BigTestStruct
r, err := gzip.NewReader(bytes.NewReader(bigTestData[:]))
if err != nil {
t.Fatal(err)
}
if err := NewDecoder(r).Decode(&value); err != nil {
t.Fatal(err)
}
want := MakeBigTestStruct()
if !reflect.DeepEqual(value, want) { if !reflect.DeepEqual(value, want) {
t.Errorf("parse fail, expect %v, get %v", want, value) t.Errorf("parse fail, expect %v, get %v", want, value)
} }
//test rawRead //test rawRead
var empty struct{} var empty struct{}
r, err = gzip.NewReader(bytes.NewReader(data)) r, err = gzip.NewReader(bytes.NewReader(bigTestData[:]))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -199,9 +206,8 @@ func TestUnmarshal_bitTest(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
//test unmarshal to interface{}
var inf interface{} var inf interface{}
r, err = gzip.NewReader(bytes.NewReader(data)) r, err = gzip.NewReader(bytes.NewReader(bigTestData[:]))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -211,6 +217,24 @@ func TestUnmarshal_bitTest(t *testing.T) {
// t.Log(inf) // t.Log(inf)
} }
func TestMarshal_bigTest(t *testing.T) {
var b bytes.Buffer
err := MarshalCompound(&b, MakeBigTestStruct(), "Level")
if err != nil {
t.Error(err)
}
rd, _ := gzip.NewReader(bytes.NewReader(bigTestData[:]))
want, err := ioutil.ReadAll(rd)
if err != nil {
t.Error(err)
}
if !bytes.Equal(b.Bytes(), want) {
t.Errorf("got:\n[% 2x]\nwant:\n[% 2x]", b.Bytes(), want)
}
}
func TestUnmarshal_IntArray(t *testing.T) { func TestUnmarshal_IntArray(t *testing.T) {
data := []byte{ data := []byte{
TagIntArray, 0, 0, TagIntArray, 0, 0,