Add buffer pool
This commit is contained in:
@ -3,8 +3,8 @@ package packet
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/zlib"
|
"compress/zlib"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Packet define a net data package
|
// Packet define a net data package
|
||||||
@ -34,18 +34,24 @@ func (p Packet) Scan(fields ...FieldDecoder) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var bufPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return new(bytes.Buffer)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Pack 打包一个数据包
|
// Pack 打包一个数据包
|
||||||
func (p *Packet) Pack(w io.Writer, threshold int) error {
|
func (p *Packet) Pack(w io.Writer, threshold int) error {
|
||||||
if threshold >= 0 {
|
if threshold >= 0 {
|
||||||
return p.withCompression(w)
|
return p.packWithCompression(w)
|
||||||
} else {
|
} else {
|
||||||
return p.withoutCompression(w)
|
return p.packWithoutCompression(w)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) withoutCompression(w io.Writer) error {
|
func (p *Packet) packWithoutCompression(w io.Writer) error {
|
||||||
var buf [5]byte
|
buffer := bufPool.Get().(*bytes.Buffer)
|
||||||
buffer := bytes.NewBuffer(buf[:0])
|
defer bufPool.Put(buffer)
|
||||||
n, err := VarInt(p.ID).WriteTo(buffer)
|
n, err := VarInt(p.ID).WriteTo(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -68,9 +74,10 @@ func (p *Packet) withoutCompression(w io.Writer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Packet) withCompression(w io.Writer) error {
|
func (p *Packet) packWithCompression(w io.Writer) error {
|
||||||
var buff bytes.Buffer
|
buff := bufPool.Get().(*bytes.Buffer)
|
||||||
zw := zlib.NewWriter(&buff)
|
defer bufPool.Put(buff)
|
||||||
|
zw := zlib.NewWriter(buff)
|
||||||
n1, err := VarInt(p.ID).WriteTo(zw)
|
n1, err := VarInt(p.ID).WriteTo(zw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -84,8 +91,9 @@ func (p *Packet) withCompression(w io.Writer) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var dataLength bytes.Buffer
|
dataLength := bufPool.Get().(*bytes.Buffer)
|
||||||
n3, err := VarInt(int(n1) + n2).WriteTo(&dataLength)
|
defer bufPool.Put(dataLength)
|
||||||
|
n3, err := VarInt(int(n1) + n2).WriteTo(dataLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -110,59 +118,79 @@ func (p *Packet) withCompression(w io.Writer) error {
|
|||||||
|
|
||||||
// UnPack in-place decompression a packet
|
// UnPack in-place decompression a packet
|
||||||
func (p *Packet) UnPack(r io.Reader, threshold int) error {
|
func (p *Packet) UnPack(r io.Reader, threshold int) error {
|
||||||
var length VarInt
|
|
||||||
if _, err := length.ReadFrom(r); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if length < 1 {
|
|
||||||
return fmt.Errorf("packet length too short")
|
|
||||||
}
|
|
||||||
buf := make([]byte, length)
|
|
||||||
if _, err := io.ReadFull(r, buf); err != nil {
|
|
||||||
return fmt.Errorf("read content of packet fail: %w", err)
|
|
||||||
}
|
|
||||||
buffer := bytes.NewBuffer(buf)
|
|
||||||
|
|
||||||
//解压数据
|
|
||||||
if threshold >= 0 {
|
if threshold >= 0 {
|
||||||
if err := unCompress(buffer); err != nil {
|
return p.unpackWithCompression(r)
|
||||||
return err
|
} else {
|
||||||
|
return p.unpackWithoutCompression(r)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
var packetID VarInt
|
|
||||||
if _, err := packetID.ReadFrom(buffer); err != nil {
|
|
||||||
return fmt.Errorf("read packet id fail: %v", err)
|
|
||||||
}
|
|
||||||
p.ID = int32(packetID)
|
|
||||||
p.Data = buffer.Bytes()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// unCompress 读取一个压缩的包
|
func (p *Packet) unpackWithoutCompression(r io.Reader) error {
|
||||||
func unCompress(data *bytes.Buffer) error {
|
var Length VarInt
|
||||||
reader := bytes.NewReader(data.Bytes())
|
_, err := Length.ReadFrom(r)
|
||||||
|
if err != nil {
|
||||||
var sizeUncompressed VarInt
|
|
||||||
if _, err := sizeUncompressed.ReadFrom(reader); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var uncompressedData []byte
|
var PacketID VarInt
|
||||||
if sizeUncompressed == 0 {
|
n, err := PacketID.ReadFrom(r)
|
||||||
uncompressedData = data.Bytes()[1:]
|
|
||||||
} else { // != 0 means compressed, let's decompress
|
|
||||||
uncompressedData = make([]byte, sizeUncompressed)
|
|
||||||
r, err := zlib.NewReader(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("decompress fail: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
p.ID = int32(PacketID)
|
||||||
_, err = io.ReadFull(r, uncompressedData)
|
|
||||||
|
lengthOfData := int(Length) - int(n)
|
||||||
|
if cap(p.Data) < lengthOfData {
|
||||||
|
p.Data = make([]byte, lengthOfData)
|
||||||
|
} else {
|
||||||
|
p.Data = p.Data[:lengthOfData]
|
||||||
|
}
|
||||||
|
_, err = io.ReadFull(r, p.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("decompress fail: %v", err)
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) unpackWithCompression(r io.Reader) error {
|
||||||
|
var PacketLength VarInt
|
||||||
|
_, err := PacketLength.ReadFrom(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var DataLength VarInt
|
||||||
|
n2, err := DataLength.ReadFrom(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var PacketID VarInt
|
||||||
|
if DataLength != 0 {
|
||||||
|
r, err = zlib.NewReader(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = PacketID.ReadFrom(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
n3, err := PacketID.ReadFrom(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
DataLength = PacketLength - VarInt(n2) - VarInt(n3)
|
||||||
|
}
|
||||||
|
if cap(p.Data) < int(DataLength) {
|
||||||
|
p.Data = make([]byte, DataLength)
|
||||||
|
} else {
|
||||||
|
p.Data = p.Data[:DataLength]
|
||||||
|
}
|
||||||
|
p.ID = int32(PacketID)
|
||||||
|
_, err = io.ReadFull(r, p.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
*data = *bytes.NewBuffer(uncompressedData)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user