diff --git a/nbt/decode.go b/nbt/decode.go index 96b4ea4..5f68704 100644 --- a/nbt/decode.go +++ b/nbt/decode.go @@ -90,6 +90,8 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { switch vk := val.Kind(); vk { default: return errors.New("cannot parse TagByte as " + vk.String()) + case reflect.Bool: + val.SetBool(value != 0) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: val.SetInt(int64(value)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: diff --git a/nbt/decode_test.go b/nbt/decode_test.go index 8f27907..863a80e 100644 --- a/nbt/decode_test.go +++ b/nbt/decode_test.go @@ -364,6 +364,29 @@ func TestDecoder_Decode_ByteArray(t *testing.T) { // t.Log(infValue) } +func TestDecoder_Decode_bool(t *testing.T) { + data := [][]byte{ + {TagByte, 0, 0, 0}, + {TagByte, 0, 0, 1}, + {TagByte, 0, 0, 2}, + {TagByte, 0, 0, 128}, + {TagByte, 0, 0, 255}, + } + want := []bool{ + false, true, true, true, true, + } + var value bool + for i, v := range data { + //Unmarshal to []byte + if err := Unmarshal(v, &value); err != nil { + t.Fatal(err) + } + if value != want[i] { + t.Errorf("parse fail, expect %v, get %v", want, value) + } + } +} + func TestDecoder_Decode_ErrorString(t *testing.T) { var data = []byte{ 0x08, 0x00, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0xFF, 0xFE, diff --git a/nbt/encode.go b/nbt/encode.go index 18e65a7..cf1e132 100644 --- a/nbt/encode.go +++ b/nbt/encode.go @@ -40,7 +40,6 @@ func NewEncoder(w io.Writer) *Encoder { // which TagByteArray, TagIntArray and TagLongArray. // To force encode them as TagList, add a struct field tag. // -// func (e *Encoder) Encode(v interface{}, tagName string) error { t, val := getTagType(reflect.ValueOf(v)) return e.marshal(val, t, tagName) @@ -65,6 +64,12 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { case TagByte: var err error switch val.Kind() { + case reflect.Bool: + var b byte + if val.Bool() { + b = 1 + } + _, err = e.w.Write([]byte{b}) case reflect.Int8: _, err = e.w.Write([]byte{byte(val.Int())}) case reflect.Uint8: @@ -88,11 +93,27 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { } if tagType == TagByteArray { - _, err := e.w.Write(*(*[]byte)((unsafe.Pointer)(&reflect.SliceHeader{ - Data: val.Pointer(), - Len: val.Len(), - Cap: val.Cap(), - }))) + var data []byte + switch val.Type().Elem().Kind() { + case reflect.Bool: + data = make([]byte, val.Len()) + for i := range data { + if val.Index(i).Bool() { + data[i] = 1 + } else { + data[i] = 0 + } + } + case reflect.Uint8: + data = val.Bytes() + case reflect.Int8: + data = *(*[]byte)((unsafe.Pointer)(&reflect.SliceHeader{ + Data: val.Pointer(), + Len: val.Len(), + Cap: val.Cap(), + })) + } + _, err := e.w.Write(data) return err } else { for i := 0; i < n; i++ { @@ -275,7 +296,7 @@ func getTagType(v reflect.Value) (byte, reflect.Value) { func getTagTypeByType(vk reflect.Type) byte { switch vk.Kind() { - case reflect.Int8, reflect.Uint8: + case reflect.Bool, reflect.Int8, reflect.Uint8: return TagByte case reflect.Int16, reflect.Uint16: return TagShort diff --git a/nbt/encode_test.go b/nbt/encode_test.go index f60b292..f634663 100644 --- a/nbt/encode_test.go +++ b/nbt/encode_test.go @@ -55,6 +55,20 @@ func TestEncoder_encodeInt8(t *testing.T) { } } +func TestEncoder_encodeBool(t *testing.T) { + // Test marshal pure Int array + v := []bool{true, false} + out := []byte{ + TagByteArray, 0x00, 0x00, 0, 0, 0, 2, + 0x01, 0x00, + } + if data, err := Marshal(v); err != nil { + t.Error(err) + } else if !bytes.Equal(data, out) { + t.Errorf("output binary not right: get % 02x, want % 02x ", data, out) + } +} + func TestEncoder_Encode_floatArray(t *testing.T) { // Test marshal pure Int array v := []float32{0.3, -100, float32(math.NaN())}