Add buffer pool

This commit is contained in:
Tnze
2021-11-27 16:17:27 +08:00
parent ab63acbd7e
commit 11dd523542

View File

@ -3,8 +3,8 @@ package packet
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
)
// Packet define a net data package
@ -34,18 +34,24 @@ func (p Packet) Scan(fields ...FieldDecoder) error {
return nil
}
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
// Pack 打包一个数据包
func (p *Packet) Pack(w io.Writer, threshold int) error {
if threshold >= 0 {
return p.withCompression(w)
return p.packWithCompression(w)
} else {
return p.withoutCompression(w)
return p.packWithoutCompression(w)
}
}
func (p *Packet) withoutCompression(w io.Writer) error {
var buf [5]byte
buffer := bytes.NewBuffer(buf[:0])
func (p *Packet) packWithoutCompression(w io.Writer) error {
buffer := bufPool.Get().(*bytes.Buffer)
defer bufPool.Put(buffer)
n, err := VarInt(p.ID).WriteTo(buffer)
if err != nil {
panic(err)
@ -68,9 +74,10 @@ func (p *Packet) withoutCompression(w io.Writer) error {
return nil
}
func (p *Packet) withCompression(w io.Writer) error {
var buff bytes.Buffer
zw := zlib.NewWriter(&buff)
func (p *Packet) packWithCompression(w io.Writer) error {
buff := bufPool.Get().(*bytes.Buffer)
defer bufPool.Put(buff)
zw := zlib.NewWriter(buff)
n1, err := VarInt(p.ID).WriteTo(zw)
if err != nil {
return err
@ -84,8 +91,9 @@ func (p *Packet) withCompression(w io.Writer) error {
return err
}
var dataLength bytes.Buffer
n3, err := VarInt(int(n1) + n2).WriteTo(&dataLength)
dataLength := bufPool.Get().(*bytes.Buffer)
defer bufPool.Put(dataLength)
n3, err := VarInt(int(n1) + n2).WriteTo(dataLength)
if err != nil {
return err
}
@ -110,59 +118,79 @@ func (p *Packet) withCompression(w io.Writer) error {
// UnPack in-place decompression a packet
func (p *Packet) UnPack(r io.Reader, threshold int) error {
var length VarInt
if _, err := length.ReadFrom(r); err != nil {
return err
}
if length < 1 {
return fmt.Errorf("packet length too short")
}
buf := make([]byte, length)
if _, err := io.ReadFull(r, buf); err != nil {
return fmt.Errorf("read content of packet fail: %w", err)
}
buffer := bytes.NewBuffer(buf)
//解压数据
if threshold >= 0 {
if err := unCompress(buffer); err != nil {
return err
return p.unpackWithCompression(r)
} else {
return p.unpackWithoutCompression(r)
}
}
var packetID VarInt
if _, err := packetID.ReadFrom(buffer); err != nil {
return fmt.Errorf("read packet id fail: %v", err)
}
p.ID = int32(packetID)
p.Data = buffer.Bytes()
return nil
}
// unCompress 读取一个压缩的包
func unCompress(data *bytes.Buffer) error {
reader := bytes.NewReader(data.Bytes())
var sizeUncompressed VarInt
if _, err := sizeUncompressed.ReadFrom(reader); err != nil {
func (p *Packet) unpackWithoutCompression(r io.Reader) error {
var Length VarInt
_, err := Length.ReadFrom(r)
if err != nil {
return err
}
var uncompressedData []byte
if sizeUncompressed == 0 {
uncompressedData = data.Bytes()[1:]
} else { // != 0 means compressed, let's decompress
uncompressedData = make([]byte, sizeUncompressed)
r, err := zlib.NewReader(reader)
var PacketID VarInt
n, err := PacketID.ReadFrom(r)
if err != nil {
return fmt.Errorf("decompress fail: %v", err)
return err
}
defer r.Close()
_, err = io.ReadFull(r, uncompressedData)
p.ID = int32(PacketID)
lengthOfData := int(Length) - int(n)
if cap(p.Data) < lengthOfData {
p.Data = make([]byte, lengthOfData)
} else {
p.Data = p.Data[:lengthOfData]
}
_, err = io.ReadFull(r, p.Data)
if err != nil {
return fmt.Errorf("decompress fail: %v", err)
return err
}
return nil
}
func (p *Packet) unpackWithCompression(r io.Reader) error {
var PacketLength VarInt
_, err := PacketLength.ReadFrom(r)
if err != nil {
return err
}
var DataLength VarInt
n2, err := DataLength.ReadFrom(r)
if err != nil {
return err
}
var PacketID VarInt
if DataLength != 0 {
r, err = zlib.NewReader(r)
if err != nil {
return err
}
_, err = PacketID.ReadFrom(r)
if err != nil {
return err
}
} else {
n3, err := PacketID.ReadFrom(r)
if err != nil {
return err
}
DataLength = PacketLength - VarInt(n2) - VarInt(n3)
}
if cap(p.Data) < int(DataLength) {
p.Data = make([]byte, DataLength)
} else {
p.Data = p.Data[:DataLength]
}
p.ID = int32(PacketID)
_, err = io.ReadFull(r, p.Data)
if err != nil {
return err
}
}
*data = *bytes.NewBuffer(uncompressedData)
return nil
}