diff --git a/nbt/decode.go b/nbt/decode.go index dd30411..5a63e20 100644 --- a/nbt/decode.go +++ b/nbt/decode.go @@ -51,13 +51,10 @@ func (d *Decoder) checkCompressed(head byte) (compress string) { var ErrEND = errors.New("unexpected TAG_End") func (d *Decoder) unmarshal(val reflect.Value, tagType byte, tagName string) error { - if val.CanInterface() { - if i, ok := val.Interface().(Unmarshaler); ok { - return i.Unmarshal(tagType, tagName, d.r) - } + u, val := indirect(val, tagType == TagEnd) + if u != nil { + return u.Unmarshal(tagType, tagName, d.r) } - // TODO: use function like json.indirect() to handle pointer better - val = val.Elem() switch tagType { default: @@ -358,6 +355,70 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte, tagName string) err return nil } +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// If it encounters an Unmarshaler, indirect stops and returns that. +// If decodingNull is true, indirect stops at the first settable pointer so it +// can be set to nil. +// +// This function is copied and modified from encoding/json +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, reflect.Value{} + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, v +} + // rawRead read and discard a value func (d *Decoder) rawRead(tagType byte) error { var buf [8]byte