Add buffer pool
This commit is contained in:
@ -3,8 +3,8 @@ package packet
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Packet define a net data package
|
||||
@ -34,18 +34,24 @@ func (p Packet) Scan(fields ...FieldDecoder) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
|
||||
// Pack 打包一个数据包
|
||||
func (p *Packet) Pack(w io.Writer, threshold int) error {
|
||||
if threshold >= 0 {
|
||||
return p.withCompression(w)
|
||||
return p.packWithCompression(w)
|
||||
} else {
|
||||
return p.withoutCompression(w)
|
||||
return p.packWithoutCompression(w)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Packet) withoutCompression(w io.Writer) error {
|
||||
var buf [5]byte
|
||||
buffer := bytes.NewBuffer(buf[:0])
|
||||
func (p *Packet) packWithoutCompression(w io.Writer) error {
|
||||
buffer := bufPool.Get().(*bytes.Buffer)
|
||||
defer bufPool.Put(buffer)
|
||||
n, err := VarInt(p.ID).WriteTo(buffer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -68,9 +74,10 @@ func (p *Packet) withoutCompression(w io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Packet) withCompression(w io.Writer) error {
|
||||
var buff bytes.Buffer
|
||||
zw := zlib.NewWriter(&buff)
|
||||
func (p *Packet) packWithCompression(w io.Writer) error {
|
||||
buff := bufPool.Get().(*bytes.Buffer)
|
||||
defer bufPool.Put(buff)
|
||||
zw := zlib.NewWriter(buff)
|
||||
n1, err := VarInt(p.ID).WriteTo(zw)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -84,8 +91,9 @@ func (p *Packet) withCompression(w io.Writer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var dataLength bytes.Buffer
|
||||
n3, err := VarInt(int(n1) + n2).WriteTo(&dataLength)
|
||||
dataLength := bufPool.Get().(*bytes.Buffer)
|
||||
defer bufPool.Put(dataLength)
|
||||
n3, err := VarInt(int(n1) + n2).WriteTo(dataLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -110,59 +118,79 @@ func (p *Packet) withCompression(w io.Writer) error {
|
||||
|
||||
// UnPack in-place decompression a packet
|
||||
func (p *Packet) UnPack(r io.Reader, threshold int) error {
|
||||
var length VarInt
|
||||
if _, err := length.ReadFrom(r); err != nil {
|
||||
return err
|
||||
}
|
||||
if length < 1 {
|
||||
return fmt.Errorf("packet length too short")
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return fmt.Errorf("read content of packet fail: %w", err)
|
||||
}
|
||||
buffer := bytes.NewBuffer(buf)
|
||||
|
||||
//解压数据
|
||||
if threshold >= 0 {
|
||||
if err := unCompress(buffer); err != nil {
|
||||
return err
|
||||
return p.unpackWithCompression(r)
|
||||
} else {
|
||||
return p.unpackWithoutCompression(r)
|
||||
}
|
||||
}
|
||||
|
||||
var packetID VarInt
|
||||
if _, err := packetID.ReadFrom(buffer); err != nil {
|
||||
return fmt.Errorf("read packet id fail: %v", err)
|
||||
}
|
||||
p.ID = int32(packetID)
|
||||
p.Data = buffer.Bytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
// unCompress 读取一个压缩的包
|
||||
func unCompress(data *bytes.Buffer) error {
|
||||
reader := bytes.NewReader(data.Bytes())
|
||||
|
||||
var sizeUncompressed VarInt
|
||||
if _, err := sizeUncompressed.ReadFrom(reader); err != nil {
|
||||
func (p *Packet) unpackWithoutCompression(r io.Reader) error {
|
||||
var Length VarInt
|
||||
_, err := Length.ReadFrom(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var uncompressedData []byte
|
||||
if sizeUncompressed == 0 {
|
||||
uncompressedData = data.Bytes()[1:]
|
||||
} else { // != 0 means compressed, let's decompress
|
||||
uncompressedData = make([]byte, sizeUncompressed)
|
||||
r, err := zlib.NewReader(reader)
|
||||
var PacketID VarInt
|
||||
n, err := PacketID.ReadFrom(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress fail: %v", err)
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
_, err = io.ReadFull(r, uncompressedData)
|
||||
p.ID = int32(PacketID)
|
||||
|
||||
lengthOfData := int(Length) - int(n)
|
||||
if cap(p.Data) < lengthOfData {
|
||||
p.Data = make([]byte, lengthOfData)
|
||||
} else {
|
||||
p.Data = p.Data[:lengthOfData]
|
||||
}
|
||||
_, err = io.ReadFull(r, p.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress fail: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Packet) unpackWithCompression(r io.Reader) error {
|
||||
var PacketLength VarInt
|
||||
_, err := PacketLength.ReadFrom(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var DataLength VarInt
|
||||
n2, err := DataLength.ReadFrom(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var PacketID VarInt
|
||||
if DataLength != 0 {
|
||||
r, err = zlib.NewReader(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = PacketID.ReadFrom(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
n3, err := PacketID.ReadFrom(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
DataLength = PacketLength - VarInt(n2) - VarInt(n3)
|
||||
}
|
||||
if cap(p.Data) < int(DataLength) {
|
||||
p.Data = make([]byte, DataLength)
|
||||
} else {
|
||||
p.Data = p.Data[:DataLength]
|
||||
}
|
||||
p.ID = int32(PacketID)
|
||||
_, err = io.ReadFull(r, p.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*data = *bytes.NewBuffer(uncompressedData)
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user