From 27a30efe7b6f449b258aa3bd9506fa4e5e9fc23e Mon Sep 17 00:00:00 2001 From: Tnze Date: Thu, 25 Feb 2021 14:24:28 +0800 Subject: [PATCH] Change conn.ReadPacket for reuse of the buffer --- README.md | 2 +- bot/ingame.go | 4 +-- bot/mcbot.go | 3 +- bot/pinglist.go | 11 +++---- cmd/simpleServer/main.go | 11 ++++--- cmd/simpleServer/status.go | 3 +- net/conn.go | 8 ++--- net/packet/packet.go | 64 +++++++++++++++++--------------------- 8 files changed, 48 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 433f6d8..8be59e0 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ Then you can send it to server using `conn.WritePacket(p)`. The `conn` is a `net Receiving packet is quite easy too. To read a packet, call `p.Scan()` like this: -接收包也非常简单,只要调用`conn.ReadPacket()`即可。而要读取包内数据则需要使用`p.Scan()`函数,就像这样: +接收包也非常简单,只要调用`conn.ReadPacket(&p)`即可。而要读取包内数据则需要使用`p.Scan()`函数,就像这样: ```go var ( diff --git a/bot/ingame.go b/bot/ingame.go index 937cf27..a128782 100644 --- a/bot/ingame.go +++ b/bot/ingame.go @@ -71,9 +71,9 @@ func (c *Client) HandleGame() error { return default: + var p pk.Packet //Read packets - p, err := c.conn.ReadPacket() - if err != nil { + if err := c.conn.ReadPacket(&p); err != nil { return } c.inbound <- p diff --git a/bot/mcbot.go b/bot/mcbot.go index 661783e..8b15972 100644 --- a/bot/mcbot.go +++ b/bot/mcbot.go @@ -106,8 +106,7 @@ func (c *Client) join(d *net.Dialer, addr string) (err error) { for { //Recive Packet var pack pk.Packet - pack, err = c.conn.ReadPacket() - if err != nil { + if err = c.conn.ReadPacket(&pack); err != nil { err = fmt.Errorf("bot: recv packet for Login fail: %v", err) return } diff --git a/bot/pinglist.go b/bot/pinglist.go index d0afd5b..1fb08b9 100644 --- a/bot/pinglist.go +++ b/bot/pinglist.go @@ -59,13 +59,13 @@ func pingAndList(addr string, port int, conn *net.Conn) ([]byte, time.Duration, return nil, 0, fmt.Errorf("bot: send list packect fail: %v", err) } + var p pk.Packet //服务器返回状态 - recv, err := conn.ReadPacket() - if err != nil { + if err := conn.ReadPacket(&p); err != nil { return nil, 0, fmt.Errorf("bot: recv list packect fail: %v", err) } var s pk.String - err = recv.Scan(&s) + err = p.Scan(&s) if err != nil { return nil, 0, fmt.Errorf("bot: scan list packect fail: %v", err) } @@ -77,12 +77,11 @@ func pingAndList(addr string, port int, conn *net.Conn) ([]byte, time.Duration, return nil, 0, fmt.Errorf("bot: send ping packect fail: %v", err) } - recv, err = conn.ReadPacket() - if err != nil { + if err = conn.ReadPacket(&p); err != nil { return nil, 0, fmt.Errorf("bot: recv pong packect fail: %v", err) } var t pk.Long - err = recv.Scan(&t) + err = p.Scan(&t) if err != nil { return nil, 0, fmt.Errorf("bot: scan pong packect fail: %v", err) } diff --git a/cmd/simpleServer/main.go b/cmd/simpleServer/main.go index 82cfceb..b73be24 100644 --- a/cmd/simpleServer/main.go +++ b/cmd/simpleServer/main.go @@ -83,7 +83,8 @@ func handlePlaying(conn net.Conn, protocol int32) { } // Just for block this goroutine. Keep the connection for { - if _, err := conn.ReadPacket(); err != nil { + var p pk.Packet + if err := conn.ReadPacket(&p); err != nil { log.Printf("ReadPacket error: %v", err) break } @@ -102,7 +103,7 @@ type PlayerInfo struct { func acceptLogin(conn net.Conn) (info PlayerInfo, err error) { //login start var p pk.Packet - p, err = conn.ReadPacket() + err = conn.ReadPacket(&p) if err != nil { return } @@ -127,14 +128,14 @@ func acceptLogin(conn net.Conn) (info PlayerInfo, err error) { // handshake receive and parse Handshake packet func handshake(conn net.Conn) (protocol, intention int32, err error) { var ( + p pk.Packet Protocol, Intention pk.VarInt ServerAddress pk.String // ignored ServerPort pk.UnsignedShort // ignored ) // receive handshake packet - p, err := conn.ReadPacket() - if err != nil { - return 0, 0, err + if err = conn.ReadPacket(&p); err != nil { + return } err = p.Scan(&Protocol, &ServerAddress, &ServerPort, &Intention) return int32(Protocol), int32(Intention), err diff --git a/cmd/simpleServer/status.go b/cmd/simpleServer/status.go index 9cb2831..6a02b67 100644 --- a/cmd/simpleServer/status.go +++ b/cmd/simpleServer/status.go @@ -10,8 +10,9 @@ import ( ) func acceptListPing(conn net.Conn) { + var p pk.Packet for i := 0; i < 2; i++ { // ping or list. Only accept twice - p, err := conn.ReadPacket() + err := conn.ReadPacket(&p) if err != nil { return } diff --git a/net/conn.go b/net/conn.go index c2927a7..137fec4 100644 --- a/net/conn.go +++ b/net/conn.go @@ -79,12 +79,8 @@ func WrapConn(conn net.Conn) *Conn { func (c *Conn) Close() error { return c.Socket.Close() } // ReadPacket read a Packet from Conn. -func (c *Conn) ReadPacket() (pk.Packet, error) { - p, err := pk.RecvPacket(c.Reader, c.threshold > 0) - if err != nil { - return pk.Packet{}, err - } - return *p, err +func (c *Conn) ReadPacket(p *pk.Packet) error { + return p.UnPack(c.Reader, c.threshold) } //WritePacket write a Packet to Conn. diff --git a/net/packet/packet.go b/net/packet/packet.go index bedfaac..9253110 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -43,7 +43,7 @@ func (p *Packet) Pack(threshold int) (pack []byte) { if len(d) > threshold { //是否需要压缩 Len := len(d) VarLen := VarInt(Len).Encode() - d = Compress(d) + d = compress(d) pack = append(pack, VarInt(len(VarLen)+len(d)).Encode()...) pack = append(pack, VarLen...) @@ -61,73 +61,67 @@ func (p *Packet) Pack(threshold int) (pack []byte) { return } -// RecvPacket receive a packet from server -func RecvPacket(r DecodeReader, useZlib bool) (*Packet, error) { +// UnPack in-place decompression a packet +func (p *Packet) UnPack(r DecodeReader, threshold int) error { var length VarInt if err := length.Decode(r); err != nil { - return nil, err + return err } if length < 1 { - return nil, fmt.Errorf("packet length too short") + return fmt.Errorf("packet length too short") } - d := make([]byte, length) // read packet content - if _, err := io.ReadFull(r, d); err != nil { - return nil, fmt.Errorf("read content of packet fail: %v", err) + buf := bytes.NewBuffer(p.Data[:0]) + if _, err := io.CopyN(buf, r, int64(length)); err != nil { + return fmt.Errorf("read content of packet fail: %w", err) } //解压数据 - if useZlib { - return UnCompress(d) + if threshold > 0 { + if err := unCompress(buf); err != nil { + return err + } } - buf := bytes.NewBuffer(d) var packetID VarInt if err := packetID.Decode(buf); err != nil { - return nil, fmt.Errorf("read packet id fail: %v", err) + return fmt.Errorf("read packet id fail: %v", err) } - return &Packet{ - ID: int32(packetID), - Data: buf.Bytes(), - }, nil + p.ID = int32(packetID) + p.Data = buf.Bytes() + return nil } -// UnCompress 读取一个压缩的包 -func UnCompress(data []byte) (*Packet, error) { - reader := bytes.NewReader(data) +// unCompress 读取一个压缩的包 +func unCompress(data *bytes.Buffer) error { + reader := bytes.NewReader(data.Bytes()) var sizeUncompressed VarInt if err := sizeUncompressed.Decode(reader); err != nil { - return nil, err + return err } - uncompressedData := make([]byte, sizeUncompressed) + var uncompressedData []byte if sizeUncompressed != 0 { // != 0 means compressed, let's decompress + uncompressedData = make([]byte, sizeUncompressed) r, err := zlib.NewReader(reader) if err != nil { - return nil, fmt.Errorf("decompress fail: %v", err) + return fmt.Errorf("decompress fail: %v", err) } defer r.Close() _, err = io.ReadFull(r, uncompressedData) if err != nil { - return nil, fmt.Errorf("decompress fail: %v", err) + return fmt.Errorf("decompress fail: %v", err) } } else { - uncompressedData = data[1:] + uncompressedData = data.Bytes()[1:] } - buf := bytes.NewBuffer(uncompressedData) - var packetID VarInt - if err := packetID.Decode(buf); err != nil { - return nil, fmt.Errorf("read packet id fail: %v", err) - } - return &Packet{ - ID: int32(packetID), - Data: buf.Bytes(), - }, nil + *data = *bytes.NewBuffer(uncompressedData) + return nil } -// Compress 压缩数据 -func Compress(data []byte) []byte { +// compress 压缩数据 +func compress(data []byte) []byte { var b bytes.Buffer w := zlib.NewWriter(&b) if _, err := w.Write(data); err != nil {