Experimental fastnbt library

This commit is contained in:
Tnze
2023-04-24 01:40:22 +08:00
parent de254fb1c6
commit ad3f69e40b
7 changed files with 565 additions and 0 deletions

File diff suppressed because one or more lines are too long

198
nbt/fastnbt/decode.go Normal file
View File

@ -0,0 +1,198 @@
package fastnbt
import (
"errors"
"fmt"
"io"
"github.com/Tnze/go-mc/nbt"
)
//func (v *Value) Parse(data []byte) {
// // TODO
//}
func (v *Value) UnmarshalNBT(tagType byte, r nbt.DecoderReader) error {
v.tag = tagType
var buf [8]byte
switch tagType {
case nbt.TagEnd:
case nbt.TagByte:
n, err := r.ReadByte()
if err != nil {
return err
}
v.data = append(v.data[:0], n)
case nbt.TagShort:
if _, err := r.Read(buf[:2]); err != nil {
return err
}
v.data = append(v.data[:0], buf[:2]...)
case nbt.TagInt, nbt.TagFloat:
if _, err := r.Read(buf[:4]); err != nil {
return err
}
v.data = append(v.data[:0], buf[:4]...)
case nbt.TagLong, nbt.TagDouble:
if _, err := r.Read(buf[:]); err != nil {
return err
}
v.data = append(v.data[:0], buf[:]...)
case nbt.TagByteArray:
n, err := readInt32(r)
if err != nil {
return err
}
v.data = append(v.data[:0], make([]byte, 4+n)...)
v.data[0], v.data[1], v.data[2], v.data[3] = byte(n>>24), byte(n>>16), byte(n>>8), byte(n)
_, err = io.ReadFull(r, v.data[4:])
if err != nil {
return err
}
case nbt.TagString:
n, err := readInt16(r)
if err != nil {
return err
}
v.data = append(v.data[:0], make([]byte, 2+n)...)
v.data[0], v.data[1] = byte(n>>8), byte(n)
_, err = io.ReadFull(r, v.data[2:])
if err != nil {
return err
}
case nbt.TagList:
t, err := r.ReadByte()
if err != nil {
return err
}
length, err := readInt32(r)
if err != nil {
return err
}
v.list = v.list[:0]
for i := int32(0); i < length; i++ {
field := new(Value)
err = field.UnmarshalNBT(t, r)
if err != nil {
return err
}
v.list = append(v.list, field)
}
case nbt.TagCompound:
for {
t, name, err := readTag(r)
if err != nil {
return err
}
if t == nbt.TagEnd {
break
}
field := new(Value)
err = field.UnmarshalNBT(t, r)
if err != nil {
return decodeErr{name, err}
}
v.comp.kvs = append(v.comp.kvs, kv{tag: name, v: field})
}
case nbt.TagIntArray:
n, err := readInt32(r)
if err != nil {
return err
}
v.data = append(v.data[:0], make([]byte, 4+n*4)...)
v.data[0], v.data[1], v.data[2], v.data[3] = byte(n>>24), byte(n>>16), byte(n>>8), byte(n)
_, err = io.ReadFull(r, v.data[4:])
if err != nil {
return err
}
case nbt.TagLongArray:
n, err := readInt32(r)
if err != nil {
return err
}
v.data = append(v.data[:0], make([]byte, 4+n*8)...)
v.data[0], v.data[1], v.data[2], v.data[3] = byte(n>>24), byte(n>>16), byte(n>>8), byte(n)
_, err = io.ReadFull(r, v.data[4:])
if err != nil {
return err
}
}
return nil
}
func readTag(r nbt.DecoderReader) (tagType byte, tagName string, err error) {
tagType, err = r.ReadByte()
if err != nil {
return
}
switch tagType {
// case 0x1f, 0x78:
case nbt.TagEnd:
default: // Read Tag
tagName, err = readString(r)
}
return
}
func readInt16(r nbt.DecoderReader) (int16, error) {
var data [2]byte
_, err := io.ReadFull(r, data[:])
return int16(data[0])<<8 | int16(data[1]), err
}
func readInt32(r nbt.DecoderReader) (int32, error) {
var data [4]byte
_, err := io.ReadFull(r, data[:])
return int32(data[0])<<24 | int32(data[1])<<16 |
int32(data[2])<<8 | int32(data[3]), err
}
func readString(r nbt.DecoderReader) (string, error) {
length, err := readInt16(r)
if err != nil {
return "", err
} else if length < 0 {
return "", errors.New("string length less than 0")
}
var str string
if length > 0 {
buf := make([]byte, length)
_, err = io.ReadFull(r, buf)
str = string(buf)
}
return str, err
}
type decodeErr struct {
decoding string
err error
}
func (d decodeErr) Error() string {
return fmt.Sprintf("fail to decode tag %q: %v", d.decoding, d.err)
}

View File

@ -0,0 +1,98 @@
package fastnbt
import (
"bytes"
_ "embed"
"testing"
"github.com/Tnze/go-mc/nbt"
)
//go:embed bigTest_test.snbt
var bigTestSNBT string
func TestValue_UnmarshalNBT(t *testing.T) {
data, err := nbt.Marshal(nbt.StringifiedMessage(bigTestSNBT))
if err != nil {
t.Fatal(err)
}
var val Value
err = nbt.Unmarshal(data, &val)
if err != nil {
t.Fatal(err)
}
if v := val.Get("longTest"); v == nil {
t.Fail()
} else if got, want := v.Long(), int64(9223372036854775807); got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("shortTest"); v == nil {
t.Fail()
} else if got, want := v.Short(), int16(32767); got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("stringTest"); v == nil {
t.Fail()
} else if got, want := v.String(), "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!"; got != want {
t.Errorf("expect %s, got: %s", want, got)
}
if v := val.Get("floatTest"); v == nil {
t.Fail()
} else if got, want := v.Float(), float32(0.49823147); got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("byteTest"); v == nil {
t.Fail()
} else if got, want := v.Byte(), int8(127); got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("intTest"); v == nil {
t.Fail()
} else if got, want := v.Int(), int32(2147483647); got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("nested compound test"); v == nil {
t.Fail()
} else if v = v.Get("ham"); v == nil {
t.Fail()
} else if v = v.Get("name"); v == nil {
t.Fail()
} else if got, want := v.String(), "Hampus"; got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("nested compound test", "ham", "name"); v == nil {
t.Fail()
} else if got, want := v.String(), "Hampus"; got != want {
t.Errorf("expect %v, got: %v", want, got)
}
if v := val.Get("listTest (long)"); v == nil {
t.Fail()
} else if list := v.List(); list == nil {
t.Fail()
} else if len(list) != 5 {
t.Fail()
} else if list[0].Long() != 11 || list[1].Long() != 12 || list[2].Long() != 13 || list[3].Long() != 14 || list[4].Long() != 15 {
t.Fail()
}
want := make([]byte, 1000)
for n := 0; n < 1000; n++ {
want[n] = byte((n*n*255 + n*7) % 100)
}
if v := val.Get("byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, starting with n=0 (0, 62, 34, 16, 8, ...))"); v == nil {
t.Fail()
} else if got := v.ByteArray(); !bytes.Equal(got, want) {
t.Errorf("expect %v", want)
t.Errorf(" got: %v", got)
}
}

90
nbt/fastnbt/encode.go Normal file
View File

@ -0,0 +1,90 @@
package fastnbt
import (
"errors"
"io"
"github.com/Tnze/go-mc/nbt"
)
func (v *Value) TagType() byte { return v.tag }
func (v *Value) MarshalNBT(w io.Writer) (err error) {
switch v.tag {
case nbt.TagEnd:
case nbt.TagByte, nbt.TagShort, nbt.TagInt, nbt.TagLong, nbt.TagFloat, nbt.TagDouble,
nbt.TagByteArray, nbt.TagString, nbt.TagIntArray, nbt.TagLongArray:
_, err = w.Write(v.data)
case nbt.TagList:
// Take a look at the first element's tag.
// If length == 0, use TagEnd
elemType := nbt.TagEnd
length := len(v.list)
if length > 0 {
elemType = v.list[0].tag
}
_, err = w.Write([]byte{elemType})
if err != nil {
return
}
err = writeInt32(w, int32(length))
if err != nil {
return
}
for _, val := range v.list {
err = val.MarshalNBT(w)
if err != nil {
return
}
}
case nbt.TagCompound:
for _, field := range v.comp.kvs {
err = writeTag(w, field.v.tag, field.tag)
if err != nil {
return
}
err = field.v.MarshalNBT(w)
if err != nil {
return
}
}
_, err = w.Write([]byte{nbt.TagEnd})
if err != nil {
return
}
default:
err = errors.New("internal: unknown tag")
}
return
}
func writeTag(w io.Writer, tagType byte, tagName string) error {
if _, err := w.Write([]byte{tagType}); err != nil {
return err
}
bName := []byte(tagName)
if err := writeInt16(w, int16(len(bName))); err != nil {
return err
}
_, err := w.Write(bName)
return err
}
func writeInt16(w io.Writer, n int16) error {
_, err := w.Write([]byte{byte(n >> 8), byte(n)})
return err
}
func writeInt32(w io.Writer, n int32) error {
_, err := w.Write([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
return err
}

View File

@ -0,0 +1 @@
package fastnbt

100
nbt/fastnbt/types.go Normal file
View File

@ -0,0 +1,100 @@
package fastnbt
import (
"math"
"github.com/Tnze/go-mc/nbt"
)
type Value struct {
comp Compound
list []*Value
data []byte
tag byte // nbt.Tag*
}
func (v *Value) Bool() bool {
if v.tag != nbt.TagByte {
return false
}
return v.data[0] != 0
}
func (v *Value) Byte() int8 {
if v.tag != nbt.TagByte {
return 0
}
return int8(v.data[0])
}
func (v *Value) Short() int16 {
if v.tag != nbt.TagShort {
return 0
}
return int16(v.data[0])<<8 | int16(v.data[1])
}
func (v *Value) Int() int32 {
if v.tag != nbt.TagInt {
return 0
}
return int32(v.data[0])<<24 | int32(v.data[1])<<16 |
int32(v.data[2])<<8 | int32(v.data[3])
}
func (v *Value) Long() int64 {
if v.tag != nbt.TagLong {
return 0
}
return int64(v.data[0])<<56 | int64(v.data[1])<<48 |
int64(v.data[2])<<40 | int64(v.data[3])<<32 |
int64(v.data[4])<<24 | int64(v.data[5])<<16 |
int64(v.data[6])<<8 | int64(v.data[7])
}
func (v *Value) Float() float32 {
if v.tag != nbt.TagFloat {
return 0
}
return math.Float32frombits(
uint32(v.data[0])<<24 | uint32(v.data[1])<<16 |
uint32(v.data[2])<<8 | uint32(v.data[3]))
}
func (v *Value) Double() float64 {
if v.tag != nbt.TagDouble {
return 0
}
return math.Float64frombits(
uint64(v.data[0])<<56 | uint64(v.data[1])<<48 |
uint64(v.data[2])<<40 | uint64(v.data[3])<<32 |
uint64(v.data[4])<<24 | uint64(v.data[5])<<16 |
uint64(v.data[6])<<8 | uint64(v.data[7]))
}
func (v *Value) List() []*Value {
return v.list
}
func (v *Value) ByteArray() []byte {
if v.tag != nbt.TagByteArray {
return nil
}
return v.data[4:]
}
func (v *Value) String() string {
if v.tag != nbt.TagString {
return ""
}
return string(v.data[2:])
}
type Compound struct {
kvs []kv
}
type kv struct {
tag string
v *Value
}

47
nbt/fastnbt/update.go Normal file
View File

@ -0,0 +1,47 @@
package fastnbt
import "github.com/Tnze/go-mc/nbt"
func (v *Value) Set(key string, val *Value) {
if v.tag != nbt.TagCompound {
panic("cannot set non-Compound Tag")
}
v.comp.Set(key, val)
}
func (v *Value) Get(keys ...string) *Value {
for _, key := range keys {
if v.tag == nbt.TagCompound {
v = v.comp.Get(key)
if v == nil {
return nil
}
} else {
return nil
}
}
return v
}
func (c *Compound) Set(key string, val *Value) {
for i := range c.kvs {
if c.kvs[i].tag == key {
c.kvs[i].v = val
return
}
}
c.kvs = append(c.kvs, kv{key, val})
}
func (c *Compound) Get(key string) *Value {
for _, tag := range c.kvs {
if tag.tag == key {
return tag.v
}
}
return nil
}
func (c *Compound) Len() int {
return len(c.kvs)
}