diff --git a/nbt/README.md b/nbt/README.md index 692e253..cf157f5 100644 --- a/nbt/README.md +++ b/nbt/README.md @@ -6,7 +6,7 @@ The API is very similar to the standard library `encoding/json`. (But fix some its problem) If you (high probability) have used that, it is easy to use this. -## Supported Struct Tags +## Supported Struct Tags and Options - `nbt` - The primary tag name. See below. - `nbtkey` - The key name of the field (Used to support commas `,` in tag names) @@ -49,7 +49,7 @@ type MyStruct struct { } ``` -### The `nbtkey` +### The `nbtkey` tag Common issue with JSON standard libraries: inability to specify keys containing commas for structures. (e.g `{"a,b" : "c"}`) diff --git a/nbt/decode.go b/nbt/decode.go index 1533176..a6a56d7 100644 --- a/nbt/decode.go +++ b/nbt/decode.go @@ -8,6 +8,7 @@ import ( "io" "math" "reflect" + "strings" ) // Unmarshal decode binary NBT data and fill into v @@ -351,11 +352,19 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } case TagCompound: + u, ut, val, assign := indirect(val, false) + if assign != nil { + defer assign() + } + if u != nil { + return u.UnmarshalNBT(tagType, d.r) + } + if ut != nil { + return errors.New("cannot decode TagCompound as string") + } switch vk := val.Kind(); vk { - default: - return errors.New("cannot parse TagCompound as " + vk.String()) case reflect.Struct: - tinfo := typeFields(val.Type()) + fields := cachedTypeFields(val.Type()) for { tt, tn, err := d.readTag() if err != nil { @@ -364,9 +373,37 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { if tt == TagEnd { break } - field := tinfo.findIndexByName(tn) - if field != -1 { - err = d.unmarshal(val.Field(field), tt) + var f *field + if i, ok := fields.nameIndex[tn]; ok { + f = &fields.list[i] + } else { + // Fall back to linear search. 
+ for i := range fields.list { + ff := &fields.list[i] + if strings.EqualFold(ff.name, tn) { + f = ff + break + } + } + } + if f != nil { + val := val + for _, i := range f.index { + if val.Kind() == reflect.Pointer { + if val.IsNil() { + // If a struct embeds a pointer to an unexported type, + // it is not possible to set a newly allocated value + // since the field is unexported. + if !val.CanSet() { + return fmt.Errorf("cannot set embedded pointer to unexported struct: %v", val.Type().Elem()) + } + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + val = val.Field(i) + } + err = d.unmarshal(val, tt) if err != nil { return fmt.Errorf("fail to decode tag %q: %w", tn, err) } @@ -377,11 +414,12 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { } } case reflect.Map: - if val.Type().Key().Kind() != reflect.String { + vt := val.Type() + if vt.Key().Kind() != reflect.String { return errors.New("cannot parse TagCompound as " + val.Type().String()) } if val.IsNil() { - val.Set(reflect.MakeMap(val.Type())) + val.Set(reflect.MakeMap(vt)) } for { tt, tn, err := d.readTag() @@ -414,6 +452,8 @@ func (d *Decoder) unmarshal(val reflect.Value, tagType byte) error { buf[tn] = value } val.Set(reflect.ValueOf(buf)) + default: + return errors.New("cannot parse TagCompound as " + vk.String()) } } diff --git a/nbt/decode_test.go b/nbt/decode_test.go index e537a2c..623b351 100644 --- a/nbt/decode_test.go +++ b/nbt/decode_test.go @@ -466,7 +466,7 @@ func TestDecoder_Decode_ErrorUnknownField(t *testing.T) { func TestDecoder_Decode_keysWithComma(t *testing.T) { data := []byte{ TagCompound, 0, 1, 'S', - TagString, 0, 1, 'T', + TagString, 0, 1, 't', 0, 4, 'T', 'n', 'z', 'e', TagEnd, } diff --git a/nbt/encode.go b/nbt/encode.go index 5086109..a77e197 100644 --- a/nbt/encode.go +++ b/nbt/encode.go @@ -181,9 +181,22 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { switch val.Kind() { case reflect.Struct: - fields := 
typeFields(val.Type()) - for _, t := range fields.fields { - v := val.Field(t.index) + fields := cachedTypeFields(val.Type()) + FieldLoop: + for i := range fields.list { + t := &fields.list[i] + + v := val + for _, i := range t.index { + if v.Kind() == reflect.Pointer { + if v.IsNil() { + continue FieldLoop + } + v = v.Elem() + } + v = v.Field(i) + } + if t.omitEmpty && isEmptyValue(v) { continue } @@ -192,7 +205,7 @@ func (e *Encoder) writeValue(val reflect.Value, tagType byte) error { return fmt.Errorf("encode %q error: unsupport type %v", t.name, v.Type()) } - if t.list { + if t.asList { switch typ { case TagByteArray, TagIntArray, TagLongArray: typ = TagList // override the parsed type diff --git a/nbt/interface.go b/nbt/interface.go index 7225497..3053ce4 100644 --- a/nbt/interface.go +++ b/nbt/interface.go @@ -10,3 +10,16 @@ type Marshaler interface { TagType() byte MarshalNBT(w io.Writer) error } + +// FieldsUnmarshaler is a type that can hold many tags, just like a TagCompound. +// +// If and only if a type implementing this interface is used as an anonymous field of a struct +// and that field has no struct tag, the tags it holds are treated as if they belonged to the outer struct. +type FieldsUnmarshaler interface { + UnmarshalField(tagType byte, tagName string, r DecoderReader) (ok bool, err error) +} + +// FieldsMarshaler is similar to FieldsUnmarshaler, but for marshaling. +type FieldsMarshaler interface { + MarshalFields(w io.Writer) (ok bool, err error) +} diff --git a/nbt/snbt_scanner.go b/nbt/snbt_scanner.go index da2b078..f605b52 100644 --- a/nbt/snbt_scanner.go +++ b/nbt/snbt_scanner.go @@ -20,7 +20,7 @@ const ( // These values are stored in the parseState stack. // They give the current state of a composite value -// being scanned. If the parser is inside a nested value +// being scanned. If the parser is inside a nested value, // the parseState describes the nested state, outermost at entry 0. 
const ( parseCompoundName = iota // parsing tag name (before colon) diff --git a/nbt/special_test.go b/nbt/special_test.go new file mode 100644 index 0000000..3109b06 --- /dev/null +++ b/nbt/special_test.go @@ -0,0 +1,82 @@ +package nbt_test + +import ( + "fmt" + "testing" + + "github.com/Tnze/go-mc/nbt" +) + +func ExampleMarshal_anonymousStructField() { + type A struct{ F string } + type B struct{ E string } + type S struct { + A // anonymous fields are usually marshaled as if their inner exported fields were fields in the outer struct + B `nbt:"B"` // anonymous field, but with an explicit tag name specified + } + + var val S + val.F = "Tnze" + val.E = "GoMC" + + data, err := nbt.Marshal(val) + if err != nil { + panic(err) + } + + var snbt nbt.StringifiedMessage + if err := nbt.Unmarshal(data, &snbt); err != nil { + panic(err) + } + fmt.Println(snbt) + + // Output: + // {F:Tnze,B:{E:GoMC}} +} + +func ExampleUnmarshal_anonymousStructField() { + type A struct{ F string } + type B struct{ E string } + type S struct { + A // anonymous fields are usually marshaled as if their inner exported fields were fields in the outer struct + B `nbt:"B"` // anonymous field, but with an explicit tag name specified + } + + data, err := nbt.Marshal(nbt.StringifiedMessage(`{F:Tnze,B:{E:GoMC}}`)) + if err != nil { + panic(err) + } + + var val S + if err := nbt.Unmarshal(data, &val); err != nil { + panic(err) + } + fmt.Println(val.F) + fmt.Println(val.E) + + // Output: + // Tnze + // GoMC +} + +func TestMarshal_anonymousPointerNesting(t *testing.T) { + type A struct{ T string } + type B struct{ *A } + type C struct{ B } + + val := C{B{&A{"Tnze"}}} + + data, err := nbt.Marshal(val) + if err != nil { + panic(err) + } + + var snbt nbt.StringifiedMessage + if err := nbt.Unmarshal(data, &snbt); err != nil { + panic(err) + } + want := `{T:Tnze}` + if string(snbt) != want { + t.Errorf("Marshal nesting anonymous struct error, got %q, want %q", snbt, want) + } +} diff --git a/nbt/typeinfo.go 
b/nbt/typeinfo.go index bec5533..5cc9638 100644 --- a/nbt/typeinfo.go +++ b/nbt/typeinfo.go @@ -2,84 +2,241 @@ package nbt import ( "reflect" + "sort" "strings" "sync" ) -type typeInfo struct { - fields []structField - nameToIndex map[string]int // index of the field in struct, not previous slice +type structFields struct { + list []field + nameIndex map[string]int // index of the previous slice. } -type structField struct { - name string - index int +type field struct { + name string + tag bool + index []int + typ reflect.Type omitEmpty bool - list bool + asList bool } -var tInfoMap sync.Map +// byIndex sorts field by index sequence. +type byIndex []field -func typeFields(typ reflect.Type) *typeInfo { - if ti, ok := tInfoMap.Load(typ); ok { - return ti.(*typeInfo) +func (x byIndex) Len() int { return len(x) } + +func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x byIndex) Less(i, j int) bool { + for k, xik := range x[i].index { + if k >= len(x[j].index) { + return false + } + if xik != x[j].index[k] { + return xik < x[j].index[k] + } } + return len(x[i].index) < len(x[j].index) +} - tInfo := new(typeInfo) - tInfo.nameToIndex = make(map[string]int) - if typ.Kind() == reflect.Struct { - n := typ.NumField() - tInfo.fields = make([]structField, 0, n) - for i := 0; i < n; i++ { - f := typ.Field(i) - tag := f.Tag.Get("nbt") - if (f.PkgPath != "" && !f.Anonymous) || tag == "-" { - continue // Private field +func typeFields(t reflect.Type) (tInfo structFields) { + // Anonymous fields to explore at the current level and the next. + current := []field{} + next := []field{{typ: t}} + + // Count of queued names for current level and the next. + var count, nextCount map[reflect.Type]int + + // Types already visited at an earlier level. + visited := make(map[reflect.Type]struct{}) + + // Fields found. 
+ var fields []field + + for len(next) > 0 { + current, next = next, current[:0] + count, nextCount = nextCount, make(map[reflect.Type]int) + + for _, f := range current { + if _, ok := visited[f.typ]; ok { + continue } + visited[f.typ] = struct{}{} - // parse tags - var field structField - name, opts, _ := strings.Cut(tag, ",") - if keytag := f.Tag.Get("nbtkey"); keytag != "" { - name = keytag - } else if name == "" { - name = f.Name - } - field.name = name - field.index = i - - // parse options - for opts != "" { - var name string - name, opts, _ = strings.Cut(opts, ",") - switch name { - case "omitempty": - field.omitEmpty = true - case "list": - field.list = true + // Scan f.typ for fields to include. + for i := 0; i < f.typ.NumField(); i++ { + sf := f.typ.Field(i) + if sf.Anonymous { + t := sf.Type + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + if !sf.IsExported() && t.Kind() != reflect.Struct { + // Ignore embedded fields of unexported non-struct types. + continue + } + // Do not ignore embedded fields of unexported struct types + // since they may have exported fields. + } else if !sf.IsExported() { + // Ignore unexported non-embedded fields. + continue } - } - if f.Tag.Get("nbt_type") == "list" { - field.list = true - } - tInfo.fields = append(tInfo.fields, field) - tInfo.nameToIndex[field.name] = i - if _, ok := tInfo.nameToIndex[f.Name]; !ok { - tInfo.nameToIndex[f.Name] = i + tag := sf.Tag.Get("nbt") + if tag == "-" { + continue + } + + // parse tags + name, opts, _ := strings.Cut(tag, ",") + index := make([]int, len(f.index)+1) + copy(index, f.index) + index[len(f.index)] = i + if keytag := sf.Tag.Get("nbtkey"); keytag != "" { + name = keytag + } + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Pointer { + // Follow pointer. 
+ ft = ft.Elem() + } + + // parse options + var omitEmpty, asList bool + for opts != "" { + var name string + name, opts, _ = strings.Cut(opts, ",") + switch name { + case "omitempty": + omitEmpty = true + case "list": + asList = true + } + } + // Deprecated: use `nbt:",list"` instead. + if sf.Tag.Get("nbt_type") == "list" { + asList = true + } + + // Record found field and index sequence. + if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { + tagged := name != "" + if name == "" { + name = sf.Name + } + field := field{ + name: name, + tag: tagged, + index: index, + typ: ft, + omitEmpty: omitEmpty, + asList: asList, + } + + fields = append(fields, field) + if count[f.typ] > 1 { + // If there were multiple instances, add a second, + // so that the annihilation code will see a duplicate. + // It only cares about the distinction between 1 or 2, + // so don't bother generating any more copies. + fields = append(fields, fields[len(fields)-1]) + } + continue + } + + // Record new anonymous struct to explore in next round. + nextCount[ft]++ + if nextCount[ft] == 1 { + next = append(next, field{name: ft.Name(), index: index, typ: ft}) + } } } } - ti, _ := tInfoMap.LoadOrStore(typ, tInfo) - return ti.(*typeInfo) + sort.Slice(fields, func(i, j int) bool { + x := fields + // sort field by name, breaking ties with depth, then + // breaking ties with "name came from nbt tag", then + // breaking ties with index sequence. + if x[i].name != x[j].name { + return x[i].name < x[j].name + } + if len(x[i].index) != len(x[j].index) { + return len(x[i].index) < len(x[j].index) + } + if x[i].tag != x[j].tag { + return x[i].tag + } + return byIndex(x).Less(i, j) + }) + + // Delete all fields that are hidden by the Go rules for embedded fields, + // except that fields with NBT tags are promoted. + + // The fields are sorted in primary order of name, secondary order + // of field index length. 
Loop over names; for each name, delete + // hidden fields by choosing the one dominant field that survives. + out := fields[:0] + for advance, i := 0, 0; i < len(fields); i += advance { + // One iteration per name. + // Find the sequence of fields with the name of this first field. + fi := fields[i] + name := fi.name + for advance = 1; i+advance < len(fields); advance++ { + fj := fields[i+advance] + if fj.name != name { + break + } + } + if advance == 1 { // Only one field with this name + out = append(out, fi) + continue + } + dominant, ok := dominantField(fields[i : i+advance]) + if ok { + out = append(out, dominant) + } + } + + fields = out + sort.Sort(byIndex(fields)) + + nameIndex := make(map[string]int, len(fields)) + for i, field := range fields { + nameIndex[field.name] = i + } + return structFields{ + list: fields, + nameIndex: nameIndex, + } } -func (t *typeInfo) findIndexByName(name string) int { - i, ok := t.nameToIndex[name] - if !ok { - return -1 +// dominantField looks through the fields, all of which are known to +// have the same name, to find the single field that dominates the +// others using Go's embedding rules, modified by the presence of +// NBT tags. If there are multiple top-level fields, the boolean +// will be false: This condition is an error in Go and we skip all +// the fields. +func dominantField(fields []field) (field, bool) { + // The fields are sorted in increasing index-length order, then by presence of tag. + // That means that the first field is the dominant one. We need only check + // for error cases: two fields at top level, either both tagged or neither tagged. 
+ if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag { + return field{}, false } - return i + return fields[0], true +} + +var fieldCache sync.Map + +func cachedTypeFields(t reflect.Type) structFields { + if ti, ok := fieldCache.Load(t); ok { + return ti.(structFields) + } + tInfo := typeFields(t) + ti, _ := fieldCache.LoadOrStore(t, tInfo) + return ti.(structFields) }