Add buffer pool

This commit is contained in:
Tnze
2021-11-27 16:17:27 +08:00
parent ab63acbd7e
commit 11dd523542

View File

@ -3,8 +3,8 @@ package packet
import ( import (
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"fmt"
"io" "io"
"sync"
) )
// Packet define a net data package // Packet define a net data package
@ -34,18 +34,24 @@ func (p Packet) Scan(fields ...FieldDecoder) error {
return nil return nil
} }
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
// Pack 打包一个数据包 // Pack 打包一个数据包
func (p *Packet) Pack(w io.Writer, threshold int) error { func (p *Packet) Pack(w io.Writer, threshold int) error {
if threshold >= 0 { if threshold >= 0 {
return p.withCompression(w) return p.packWithCompression(w)
} else { } else {
return p.withoutCompression(w) return p.packWithoutCompression(w)
} }
} }
func (p *Packet) withoutCompression(w io.Writer) error { func (p *Packet) packWithoutCompression(w io.Writer) error {
var buf [5]byte buffer := bufPool.Get().(*bytes.Buffer)
buffer := bytes.NewBuffer(buf[:0]) defer bufPool.Put(buffer)
n, err := VarInt(p.ID).WriteTo(buffer) n, err := VarInt(p.ID).WriteTo(buffer)
if err != nil { if err != nil {
panic(err) panic(err)
@ -68,9 +74,10 @@ func (p *Packet) withoutCompression(w io.Writer) error {
return nil return nil
} }
func (p *Packet) withCompression(w io.Writer) error { func (p *Packet) packWithCompression(w io.Writer) error {
var buff bytes.Buffer buff := bufPool.Get().(*bytes.Buffer)
zw := zlib.NewWriter(&buff) defer bufPool.Put(buff)
zw := zlib.NewWriter(buff)
n1, err := VarInt(p.ID).WriteTo(zw) n1, err := VarInt(p.ID).WriteTo(zw)
if err != nil { if err != nil {
return err return err
@ -84,8 +91,9 @@ func (p *Packet) withCompression(w io.Writer) error {
return err return err
} }
var dataLength bytes.Buffer dataLength := bufPool.Get().(*bytes.Buffer)
n3, err := VarInt(int(n1) + n2).WriteTo(&dataLength) defer bufPool.Put(dataLength)
n3, err := VarInt(int(n1) + n2).WriteTo(dataLength)
if err != nil { if err != nil {
return err return err
} }
@ -110,59 +118,79 @@ func (p *Packet) withCompression(w io.Writer) error {
// UnPack in-place decompression a packet // UnPack in-place decompression a packet
func (p *Packet) UnPack(r io.Reader, threshold int) error { func (p *Packet) UnPack(r io.Reader, threshold int) error {
var length VarInt
if _, err := length.ReadFrom(r); err != nil {
return err
}
if length < 1 {
return fmt.Errorf("packet length too short")
}
buf := make([]byte, length)
if _, err := io.ReadFull(r, buf); err != nil {
return fmt.Errorf("read content of packet fail: %w", err)
}
buffer := bytes.NewBuffer(buf)
//解压数据
if threshold >= 0 { if threshold >= 0 {
if err := unCompress(buffer); err != nil { return p.unpackWithCompression(r)
return err } else {
return p.unpackWithoutCompression(r)
} }
} }
var packetID VarInt func (p *Packet) unpackWithoutCompression(r io.Reader) error {
if _, err := packetID.ReadFrom(buffer); err != nil { var Length VarInt
return fmt.Errorf("read packet id fail: %v", err) _, err := Length.ReadFrom(r)
} if err != nil {
p.ID = int32(packetID)
p.Data = buffer.Bytes()
return nil
}
// unCompress 读取一个压缩的包
func unCompress(data *bytes.Buffer) error {
reader := bytes.NewReader(data.Bytes())
var sizeUncompressed VarInt
if _, err := sizeUncompressed.ReadFrom(reader); err != nil {
return err return err
} }
var uncompressedData []byte var PacketID VarInt
if sizeUncompressed == 0 { n, err := PacketID.ReadFrom(r)
uncompressedData = data.Bytes()[1:]
} else { // != 0 means compressed, let's decompress
uncompressedData = make([]byte, sizeUncompressed)
r, err := zlib.NewReader(reader)
if err != nil { if err != nil {
return fmt.Errorf("decompress fail: %v", err) return err
} }
defer r.Close() p.ID = int32(PacketID)
_, err = io.ReadFull(r, uncompressedData)
lengthOfData := int(Length) - int(n)
if cap(p.Data) < lengthOfData {
p.Data = make([]byte, lengthOfData)
} else {
p.Data = p.Data[:lengthOfData]
}
_, err = io.ReadFull(r, p.Data)
if err != nil { if err != nil {
return fmt.Errorf("decompress fail: %v", err) return err
}
return nil
}
func (p *Packet) unpackWithCompression(r io.Reader) error {
var PacketLength VarInt
_, err := PacketLength.ReadFrom(r)
if err != nil {
return err
}
var DataLength VarInt
n2, err := DataLength.ReadFrom(r)
if err != nil {
return err
}
var PacketID VarInt
if DataLength != 0 {
r, err = zlib.NewReader(r)
if err != nil {
return err
}
_, err = PacketID.ReadFrom(r)
if err != nil {
return err
}
} else {
n3, err := PacketID.ReadFrom(r)
if err != nil {
return err
}
DataLength = PacketLength - VarInt(n2) - VarInt(n3)
}
if cap(p.Data) < int(DataLength) {
p.Data = make([]byte, DataLength)
} else {
p.Data = p.Data[:DataLength]
}
p.ID = int32(PacketID)
_, err = io.ReadFull(r, p.Data)
if err != nil {
return err
} }
}
*data = *bytes.NewBuffer(uncompressedData)
return nil return nil
} }