diff --git a/nbt/encode.go b/nbt/encode.go index a7ee4ff..8350961 100644 --- a/nbt/encode.go +++ b/nbt/encode.go @@ -40,8 +40,8 @@ func NewEncoder(w io.Writer) *Encoder { // You haven't ability to encode them as TagList as root element at this time, // issue or pull-request is welcome. func (e *Encoder) Encode(v interface{}, tagName string) error { - val := reflect.ValueOf(v) - return e.marshal(val, getTagType(val), tagName) + t, val := getTagType(reflect.ValueOf(v)) + return e.marshal(val, t, tagName) } func (e *Encoder) marshal(val reflect.Value, tagType byte, tagName string) error { @@ -101,7 +101,7 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { case TagList: var eleType byte if val.Len() > 0 { - eleType = getTagType(val.Index(0)) + eleType, _ = getTagType(val.Index(0)) } else { eleType = getTagTypeByType(val.Type().Elem()) } @@ -110,8 +110,8 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { } for i := 0; i < val.Len(); i++ { - arrVal := val.Index(i) - err := e.writeValue(arrVal, getTagType(arrVal)) + arrType, arrVal := getTagType(val.Index(i)) + err := e.writeValue(arrVal, arrType) if err != nil { return err } @@ -154,8 +154,7 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { } else { tagName = r.Key().String() } - tagValue := r.Value() - tagType := getTagType(tagValue) + tagType, tagValue := getTagType(r.Value()) if tagType == TagNone { return errors.New("unsupported value " + tagValue.String()) } @@ -172,32 +171,64 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { return nil } -func getTagType(v reflect.Value) byte { +func getTagType(v reflect.Value) (byte, reflect.Value) { + for { + // Load value from interface + if v.Kind() == reflect.Interface && !v.IsNil() { + v = v.Elem() + continue + } + + if v.Kind() != reflect.Ptr { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(NBTEncoder); ok { + return u.TagType(), v + } + } + + v = v.Elem() + } + if v.CanInterface() { if encoder, ok := v.Interface().(NBTEncoder); ok { - return encoder.TagType() + return encoder.TagType(), v } } + switch v.Kind() { case reflect.Array, reflect.Slice: var elemType byte if v.Len() > 0 { - elemType = getTagType(v.Index(0)) + elemType, _ = getTagType(v.Index(0)) } else { elemType = getTagTypeByType(v.Type().Elem()) } switch elemType { case TagByte: // Special types for these values - return TagByteArray + return TagByteArray, v case TagInt: - return TagIntArray + return TagIntArray, v case TagLong: - return TagLongArray + return TagLongArray, v default: - return TagList + return TagList, v } + default: - return getTagTypeByType(v.Type()) + return getTagTypeByType(v.Type()), v } } @@ -217,7 +248,7 @@ func getTagTypeByType(vk reflect.Type) byte { return TagDouble case reflect.String: return TagString - case reflect.Struct, reflect.Interface, reflect.Map: + case reflect.Struct, reflect.Map: return TagCompound default: return TagNone @@ -237,7 +268,7 @@ func parseTag(f reflect.StructField, v reflect.Value, tagName string) (result ta } nbtType := f.Tag.Get("nbt_type") - result.Type = getTagType(v) + result.Type, v = getTagType(v) if strings.Contains(nbtType, "list") { if IsArrayTag(result.Type) { result.Type = TagList // for expanding the array to a standard list diff --git a/nbt/encode_test.go b/nbt/encode_test.go index dafd272..6f7f721 100644 --- a/nbt/encode_test.go +++ b/nbt/encode_test.go @@ -230,3 +230,26 @@ func TestRawMessage_Encode(t *testing.T) { t.Fatalf("Encode error: want %v, get: %v", data, buf.Bytes()) } } + +func TestEncoder_Encode_interface(t *testing.T) { + data := map[string]interface{}{ + "Key": int32(12), + "Value": "Tnze", + } + var buf bytes.Buffer + if err := NewEncoder(&buf).Encode(data, "ab"); err != nil { + t.Fatalf("Encode error: %v", err) + } + + var container struct { + Key int32 + Value string + } + if _, err := NewDecoder(&buf).Decode(&container); err != nil { + t.Fatalf("Decode error: %v", err) + } + + if container.Key != 12 || container.Value != "Tnze" { + t.Fatalf("want: (%v, %v), but got (%v, %v)", 12, "Tnze", container.Key, container.Value) + } +}