diff --git a/net/packet/packet.go b/net/packet/packet.go index 22e239c..9737ef2 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -3,8 +3,8 @@ package packet import ( "bytes" "compress/zlib" - "fmt" "io" + "sync" ) // Packet define a net data package @@ -34,18 +34,24 @@ func (p Packet) Scan(fields ...FieldDecoder) error { return nil } +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + // Pack 打包一个数据包 func (p *Packet) Pack(w io.Writer, threshold int) error { if threshold >= 0 { - return p.withCompression(w) + return p.packWithCompression(w) } else { - return p.withoutCompression(w) + return p.packWithoutCompression(w) } } -func (p *Packet) withoutCompression(w io.Writer) error { - var buf [5]byte - buffer := bytes.NewBuffer(buf[:0]) +func (p *Packet) packWithoutCompression(w io.Writer) error { + buffer := bufPool.Get().(*bytes.Buffer) + defer bufPool.Put(buffer) n, err := VarInt(p.ID).WriteTo(buffer) if err != nil { panic(err) @@ -68,9 +74,10 @@ func (p *Packet) withoutCompression(w io.Writer) error { return nil } -func (p *Packet) withCompression(w io.Writer) error { - var buff bytes.Buffer - zw := zlib.NewWriter(&buff) +func (p *Packet) packWithCompression(w io.Writer) error { + buff := bufPool.Get().(*bytes.Buffer) + defer bufPool.Put(buff) + zw := zlib.NewWriter(buff) n1, err := VarInt(p.ID).WriteTo(zw) if err != nil { return err @@ -84,8 +91,9 @@ func (p *Packet) withCompression(w io.Writer) error { return err } - var dataLength bytes.Buffer - n3, err := VarInt(int(n1) + n2).WriteTo(&dataLength) + dataLength := bufPool.Get().(*bytes.Buffer) + defer bufPool.Put(dataLength) + n3, err := VarInt(int(n1) + n2).WriteTo(dataLength) if err != nil { return err } @@ -110,59 +118,79 @@ func (p *Packet) withCompression(w io.Writer) error { // UnPack in-place decompression a packet func (p *Packet) UnPack(r io.Reader, threshold int) error { - var length VarInt - if _, err := length.ReadFrom(r); err != nil { + if threshold >= 0 { + return p.unpackWithCompression(r) + } else { + return p.unpackWithoutCompression(r) + } +} + +func (p *Packet) unpackWithoutCompression(r io.Reader) error { + var Length VarInt + _, err := Length.ReadFrom(r) + if err != nil { return err } - if length < 1 { - return fmt.Errorf("packet length too short") - } - buf := make([]byte, length) - if _, err := io.ReadFull(r, buf); err != nil { - return fmt.Errorf("read content of packet fail: %w", err) - } - buffer := bytes.NewBuffer(buf) - //解压数据 - if threshold >= 0 { - if err := unCompress(buffer); err != nil { + var PacketID VarInt + n, err := PacketID.ReadFrom(r) + if err != nil { + return err + } + p.ID = int32(PacketID) + + lengthOfData := int(Length) - int(n) + if cap(p.Data) < lengthOfData { + p.Data = make([]byte, lengthOfData) + } else { + p.Data = p.Data[:lengthOfData] + } + _, err = io.ReadFull(r, p.Data) + if err != nil { + return err + } + return nil +} + +func (p *Packet) unpackWithCompression(r io.Reader) error { + var PacketLength VarInt + _, err := PacketLength.ReadFrom(r) + if err != nil { + return err + } + + var DataLength VarInt + n2, err := DataLength.ReadFrom(r) + if err != nil { + return err + } + + var PacketID VarInt + if DataLength != 0 { + r, err = zlib.NewReader(r) + if err != nil { return err } + _, err = PacketID.ReadFrom(r) + if err != nil { + return err + } + } else { + n3, err := PacketID.ReadFrom(r) + if err != nil { + return err + } + DataLength = PacketLength - VarInt(n2) - VarInt(n3) } - - var packetID VarInt - if _, err := packetID.ReadFrom(buffer); err != nil { - return fmt.Errorf("read packet id fail: %v", err) + if cap(p.Data) < int(DataLength) { + p.Data = make([]byte, DataLength) + } else { + p.Data = p.Data[:DataLength] } - p.ID = int32(packetID) - p.Data = buffer.Bytes() - return nil -} - -// unCompress 读取一个压缩的包 -func unCompress(data *bytes.Buffer) error { - reader := bytes.NewReader(data.Bytes()) - - var sizeUncompressed VarInt - if _, err := sizeUncompressed.ReadFrom(reader); err != nil { + p.ID = int32(PacketID) + _, err = io.ReadFull(r, p.Data) + if err != nil { return err } - - var uncompressedData []byte - if sizeUncompressed == 0 { - uncompressedData = data.Bytes()[1:] - } else { // != 0 means compressed, let's decompress - uncompressedData = make([]byte, sizeUncompressed) - r, err := zlib.NewReader(reader) - if err != nil { - return fmt.Errorf("decompress fail: %v", err) - } - defer r.Close() - _, err = io.ReadFull(r, uncompressedData) - if err != nil { - return fmt.Errorf("decompress fail: %v", err) - } - } - *data = *bytes.NewBuffer(uncompressedData) return nil }