Change conn.ReadPacket for reuse of the buffer

This commit is contained in:
Tnze
2021-02-25 14:24:28 +08:00
parent 5feb25895e
commit 27a30efe7b
8 changed files with 48 additions and 58 deletions

View File

@ -81,7 +81,7 @@ Then you can send it to server using `conn.WritePacket(p)`. The `conn` is a `net
Receiving packet is quite easy too. To read a packet, call `p.Scan()` like this: Receiving packet is quite easy too. To read a packet, call `p.Scan()` like this:
接收包也非常简单,只要调用`conn.ReadPacket()`即可。而要读取包内数据则需要使用`p.Scan()`函数,就像这样: 接收包也非常简单,只要调用`conn.ReadPacket(&p)`即可。而要读取包内数据则需要使用`p.Scan()`函数,就像这样:
```go ```go
var ( var (

View File

@ -71,9 +71,9 @@ func (c *Client) HandleGame() error {
return return
default: default:
var p pk.Packet
//Read packets //Read packets
p, err := c.conn.ReadPacket() if err := c.conn.ReadPacket(&p); err != nil {
if err != nil {
return return
} }
c.inbound <- p c.inbound <- p

View File

@ -106,8 +106,7 @@ func (c *Client) join(d *net.Dialer, addr string) (err error) {
for { for {
//Recive Packet //Recive Packet
var pack pk.Packet var pack pk.Packet
pack, err = c.conn.ReadPacket() if err = c.conn.ReadPacket(&pack); err != nil {
if err != nil {
err = fmt.Errorf("bot: recv packet for Login fail: %v", err) err = fmt.Errorf("bot: recv packet for Login fail: %v", err)
return return
} }

View File

@ -59,13 +59,13 @@ func pingAndList(addr string, port int, conn *net.Conn) ([]byte, time.Duration,
return nil, 0, fmt.Errorf("bot: send list packect fail: %v", err) return nil, 0, fmt.Errorf("bot: send list packect fail: %v", err)
} }
var p pk.Packet
//服务器返回状态 //服务器返回状态
recv, err := conn.ReadPacket() if err := conn.ReadPacket(&p); err != nil {
if err != nil {
return nil, 0, fmt.Errorf("bot: recv list packect fail: %v", err) return nil, 0, fmt.Errorf("bot: recv list packect fail: %v", err)
} }
var s pk.String var s pk.String
err = recv.Scan(&s) err = p.Scan(&s)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("bot: scan list packect fail: %v", err) return nil, 0, fmt.Errorf("bot: scan list packect fail: %v", err)
} }
@ -77,12 +77,11 @@ func pingAndList(addr string, port int, conn *net.Conn) ([]byte, time.Duration,
return nil, 0, fmt.Errorf("bot: send ping packect fail: %v", err) return nil, 0, fmt.Errorf("bot: send ping packect fail: %v", err)
} }
recv, err = conn.ReadPacket() if err = conn.ReadPacket(&p); err != nil {
if err != nil {
return nil, 0, fmt.Errorf("bot: recv pong packect fail: %v", err) return nil, 0, fmt.Errorf("bot: recv pong packect fail: %v", err)
} }
var t pk.Long var t pk.Long
err = recv.Scan(&t) err = p.Scan(&t)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("bot: scan pong packect fail: %v", err) return nil, 0, fmt.Errorf("bot: scan pong packect fail: %v", err)
} }

View File

@ -83,7 +83,8 @@ func handlePlaying(conn net.Conn, protocol int32) {
} }
// Just for block this goroutine. Keep the connection // Just for block this goroutine. Keep the connection
for { for {
if _, err := conn.ReadPacket(); err != nil { var p pk.Packet
if err := conn.ReadPacket(&p); err != nil {
log.Printf("ReadPacket error: %v", err) log.Printf("ReadPacket error: %v", err)
break break
} }
@ -102,7 +103,7 @@ type PlayerInfo struct {
func acceptLogin(conn net.Conn) (info PlayerInfo, err error) { func acceptLogin(conn net.Conn) (info PlayerInfo, err error) {
//login start //login start
var p pk.Packet var p pk.Packet
p, err = conn.ReadPacket() err = conn.ReadPacket(&p)
if err != nil { if err != nil {
return return
} }
@ -127,14 +128,14 @@ func acceptLogin(conn net.Conn) (info PlayerInfo, err error) {
// handshake receive and parse Handshake packet // handshake receive and parse Handshake packet
func handshake(conn net.Conn) (protocol, intention int32, err error) { func handshake(conn net.Conn) (protocol, intention int32, err error) {
var ( var (
p pk.Packet
Protocol, Intention pk.VarInt Protocol, Intention pk.VarInt
ServerAddress pk.String // ignored ServerAddress pk.String // ignored
ServerPort pk.UnsignedShort // ignored ServerPort pk.UnsignedShort // ignored
) )
// receive handshake packet // receive handshake packet
p, err := conn.ReadPacket() if err = conn.ReadPacket(&p); err != nil {
if err != nil { return
return 0, 0, err
} }
err = p.Scan(&Protocol, &ServerAddress, &ServerPort, &Intention) err = p.Scan(&Protocol, &ServerAddress, &ServerPort, &Intention)
return int32(Protocol), int32(Intention), err return int32(Protocol), int32(Intention), err

View File

@ -10,8 +10,9 @@ import (
) )
func acceptListPing(conn net.Conn) { func acceptListPing(conn net.Conn) {
var p pk.Packet
for i := 0; i < 2; i++ { // ping or list. Only accept twice for i := 0; i < 2; i++ { // ping or list. Only accept twice
p, err := conn.ReadPacket() err := conn.ReadPacket(&p)
if err != nil { if err != nil {
return return
} }

View File

@ -79,12 +79,8 @@ func WrapConn(conn net.Conn) *Conn {
func (c *Conn) Close() error { return c.Socket.Close() } func (c *Conn) Close() error { return c.Socket.Close() }
// ReadPacket read a Packet from Conn. // ReadPacket read a Packet from Conn.
func (c *Conn) ReadPacket() (pk.Packet, error) { func (c *Conn) ReadPacket(p *pk.Packet) error {
p, err := pk.RecvPacket(c.Reader, c.threshold > 0) return p.UnPack(c.Reader, c.threshold)
if err != nil {
return pk.Packet{}, err
}
return *p, err
} }
//WritePacket write a Packet to Conn. //WritePacket write a Packet to Conn.

View File

@ -43,7 +43,7 @@ func (p *Packet) Pack(threshold int) (pack []byte) {
if len(d) > threshold { //是否需要压缩 if len(d) > threshold { //是否需要压缩
Len := len(d) Len := len(d)
VarLen := VarInt(Len).Encode() VarLen := VarInt(Len).Encode()
d = Compress(d) d = compress(d)
pack = append(pack, VarInt(len(VarLen)+len(d)).Encode()...) pack = append(pack, VarInt(len(VarLen)+len(d)).Encode()...)
pack = append(pack, VarLen...) pack = append(pack, VarLen...)
@ -61,73 +61,67 @@ func (p *Packet) Pack(threshold int) (pack []byte) {
return return
} }
// RecvPacket receive a packet from server // UnPack in-place decompression a packet
func RecvPacket(r DecodeReader, useZlib bool) (*Packet, error) { func (p *Packet) UnPack(r DecodeReader, threshold int) error {
var length VarInt var length VarInt
if err := length.Decode(r); err != nil { if err := length.Decode(r); err != nil {
return nil, err return err
} }
if length < 1 { if length < 1 {
return nil, fmt.Errorf("packet length too short") return fmt.Errorf("packet length too short")
} }
d := make([]byte, length) // read packet content buf := bytes.NewBuffer(p.Data[:0])
if _, err := io.ReadFull(r, d); err != nil { if _, err := io.CopyN(buf, r, int64(length)); err != nil {
return nil, fmt.Errorf("read content of packet fail: %v", err) return fmt.Errorf("read content of packet fail: %w", err)
} }
//解压数据 //解压数据
if useZlib { if threshold > 0 {
return UnCompress(d) if err := unCompress(buf); err != nil {
return err
}
} }
buf := bytes.NewBuffer(d)
var packetID VarInt var packetID VarInt
if err := packetID.Decode(buf); err != nil { if err := packetID.Decode(buf); err != nil {
return nil, fmt.Errorf("read packet id fail: %v", err) return fmt.Errorf("read packet id fail: %v", err)
} }
return &Packet{ p.ID = int32(packetID)
ID: int32(packetID), p.Data = buf.Bytes()
Data: buf.Bytes(), return nil
}, nil
} }
// UnCompress 读取一个压缩的包 // unCompress 读取一个压缩的包
func UnCompress(data []byte) (*Packet, error) { func unCompress(data *bytes.Buffer) error {
reader := bytes.NewReader(data) reader := bytes.NewReader(data.Bytes())
var sizeUncompressed VarInt var sizeUncompressed VarInt
if err := sizeUncompressed.Decode(reader); err != nil { if err := sizeUncompressed.Decode(reader); err != nil {
return nil, err return err
} }
uncompressedData := make([]byte, sizeUncompressed) var uncompressedData []byte
if sizeUncompressed != 0 { // != 0 means compressed, let's decompress if sizeUncompressed != 0 { // != 0 means compressed, let's decompress
uncompressedData = make([]byte, sizeUncompressed)
r, err := zlib.NewReader(reader) r, err := zlib.NewReader(reader)
if err != nil { if err != nil {
return nil, fmt.Errorf("decompress fail: %v", err) return fmt.Errorf("decompress fail: %v", err)
} }
defer r.Close() defer r.Close()
_, err = io.ReadFull(r, uncompressedData) _, err = io.ReadFull(r, uncompressedData)
if err != nil { if err != nil {
return nil, fmt.Errorf("decompress fail: %v", err) return fmt.Errorf("decompress fail: %v", err)
} }
} else { } else {
uncompressedData = data[1:] uncompressedData = data.Bytes()[1:]
} }
buf := bytes.NewBuffer(uncompressedData) *data = *bytes.NewBuffer(uncompressedData)
var packetID VarInt return nil
if err := packetID.Decode(buf); err != nil {
return nil, fmt.Errorf("read packet id fail: %v", err)
}
return &Packet{
ID: int32(packetID),
Data: buf.Bytes(),
}, nil
} }
// Compress 压缩数据 // compress 压缩数据
func Compress(data []byte) []byte { func compress(data []byte) []byte {
var b bytes.Buffer var b bytes.Buffer
w := zlib.NewWriter(&b) w := zlib.NewWriter(&b)
if _, err := w.Write(data); err != nil { if _, err := w.Write(data); err != nil {