From b34044cff9e87d527698f833e47ffdcddd799105 Mon Sep 17 00:00:00 2001 From: Tnze Date: Tue, 30 Nov 2021 16:12:12 +0800 Subject: [PATCH] fix #132 --- bot/screen/screen.go | 9 ++- net/packet/packet.go | 129 +++++++++++++++++++++++++++++-------------- 2 files changed, 96 insertions(+), 42 deletions(-) diff --git a/bot/screen/screen.go b/bot/screen/screen.go index c3c31d5..83c1624 100644 --- a/bot/screen/screen.go +++ b/bot/screen/screen.go @@ -56,15 +56,20 @@ func (m *Manager) onOpenScreen(p pk.Packet) error { func (m *Manager) onSetContentPacket(p pk.Packet) error { var ( ContainerID pk.UnsignedByte - Count pk.Short + StateID pk.VarInt + Count pk.VarInt SlotData []Slot + CarriedItem Slot ) if err := p.Scan( &ContainerID, + &StateID, &Count, pk.Ary{ Len: &Count, Ary: &SlotData, - }); err != nil { + }, + &CarriedItem, + ); err != nil { return Error{err} } // copy the slot data to container diff --git a/net/packet/packet.go b/net/packet/packet.go index 9737ef2..80af644 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -3,6 +3,7 @@ package packet import ( "bytes" "compress/zlib" + "fmt" "io" "sync" ) @@ -43,7 +44,7 @@ var bufPool = sync.Pool{ // Pack 打包一个数据包 func (p *Packet) Pack(w io.Writer, threshold int) error { if threshold >= 0 { - return p.packWithCompression(w) + return p.packWithCompression(w, threshold) } else { return p.packWithoutCompression(w) } @@ -52,6 +53,7 @@ func (p *Packet) Pack(w io.Writer, threshold int) error { func (p *Packet) packWithoutCompression(w io.Writer) error { buffer := bufPool.Get().(*bytes.Buffer) defer bufPool.Put(buffer) + buffer.Reset() n, err := VarInt(p.ID).WriteTo(buffer) if err != nil { panic(err) @@ -74,44 +76,71 @@ func (p *Packet) packWithoutCompression(w io.Writer) error { return nil } -func (p *Packet) packWithCompression(w io.Writer) error { +func (p *Packet) packWithCompression(w io.Writer, threshold int) error { buff := bufPool.Get().(*bytes.Buffer) defer bufPool.Put(buff) - zw := zlib.NewWriter(buff) - n1, err := VarInt(p.ID).WriteTo(zw) - if err != nil { - return err - } - n2, err := zw.Write(p.Data) - if err != nil { - return err - } - err = zw.Close() - if err != nil { - return err - } + buff.Reset() - dataLength := bufPool.Get().(*bytes.Buffer) - defer bufPool.Put(dataLength) - n3, err := VarInt(int(n1) + n2).WriteTo(dataLength) - if err != nil { - return err - } + if len(p.Data) < threshold { + _, err := VarInt(0).WriteTo(buff) + if err != nil { + return err + } + _, err = VarInt(p.ID).WriteTo(buff) + if err != nil { + return err + } + _, err = buff.Write(p.Data) + if err != nil { + return err + } + // Packet Length + _, err = VarInt(buff.Len()).WriteTo(w) + if err != nil { + return err + } + // Data Length + Packet ID + Data + _, err = buff.WriteTo(w) + if err != nil { + return err + } + } else { + zw := zlib.NewWriter(buff) + n1, err := VarInt(p.ID).WriteTo(zw) + if err != nil { + return err + } + n2, err := zw.Write(p.Data) + if err != nil { + return err + } + err = zw.Close() + if err != nil { + return err + } - // Packet Length - _, err = VarInt(int(n3) + buff.Len()).WriteTo(w) - if err != nil { - return err - } - // Data Length - _, err = dataLength.WriteTo(w) - if err != nil { - return err - } - // PacketID + Data - _, err = buff.WriteTo(w) - if err != nil { - return err + dataLength := bufPool.Get().(*bytes.Buffer) + defer bufPool.Put(dataLength) + n3, err := VarInt(int(n1) + n2).WriteTo(dataLength) + if err != nil { + return err + } + + // Packet Length + _, err = VarInt(int(n3) + buff.Len()).WriteTo(w) + if err != nil { + return err + } + // Data Length + _, err = dataLength.WriteTo(w) + if err != nil { + return err + } + // PacketID + Data + _, err = buff.WriteTo(w) + if err != nil { + return err + } } return nil } @@ -119,7 +148,7 @@ func (p *Packet) packWithCompression(w io.Writer) error { // UnPack in-place decompression a packet func (p *Packet) UnPack(r io.Reader, threshold int) error { if threshold >= 0 { - return p.unpackWithCompression(r) + return p.unpackWithCompression(r, threshold) } else { return p.unpackWithoutCompression(r) } @@ -152,13 +181,23 @@ func (p *Packet) unpackWithoutCompression(r io.Reader) error { return nil } -func (p *Packet) unpackWithCompression(r io.Reader) error { +func (p *Packet) unpackWithCompression(r io.Reader, threshold int) error { var PacketLength VarInt _, err := PacketLength.ReadFrom(r) if err != nil { return err } + buff := bufPool.Get().(*bytes.Buffer) + defer bufPool.Put(buff) + buff.Reset() + + _, err = io.CopyN(buff, r, int64(PacketLength)) + if err != nil { + return err + } + r = bytes.NewReader(buff.Bytes()) + var DataLength VarInt n2, err := DataLength.ReadFrom(r) if err != nil { @@ -167,20 +206,30 @@ func (p *Packet) unpackWithCompression(r io.Reader) error { var PacketID VarInt if DataLength != 0 { - r, err = zlib.NewReader(r) + if int(DataLength) < threshold { + return fmt.Errorf("compressed packet error: size of %d is below threshold of %d", DataLength, threshold) + } + const MaxDataLength = 2097152 + if DataLength > MaxDataLength { + return fmt.Errorf("compressed packet error: size of %d is larger than protocol maximum of %d", DataLength, MaxDataLength) + } + zr, err := zlib.NewReader(r) if err != nil { return err } - _, err = PacketID.ReadFrom(r) + defer zr.Close() + r = zr + n3, err := PacketID.ReadFrom(r) if err != nil { return err } + DataLength -= VarInt(n3) } else { n3, err := PacketID.ReadFrom(r) if err != nil { return err } - DataLength = PacketLength - VarInt(n2) - VarInt(n3) + DataLength = VarInt(int64(PacketLength) - n2 - n3) } if cap(p.Data) < int(DataLength) { p.Data = make([]byte, DataLength)