Use packet queue for bot framework.

Providing concurrently-safing for sending packets.
This commit is contained in:
Tnze
2023-03-05 10:19:18 +08:00
parent 90501f1357
commit 925b1359fc
5 changed files with 85 additions and 14 deletions

View File

@ -51,7 +51,7 @@ func (p *Player) Respawn() error {
const PerformRespawn = 0
err := p.c.Conn.WritePacket(pk.Marshal(
int32(packetid.ServerboundClientCommand),
packetid.ServerboundClientCommand,
pk.VarInt(PerformRespawn),
))
if err != nil {

View File

@ -54,7 +54,7 @@ func (p *Player) handleLoginPacket(packet pk.Packet) error {
return Error{err}
}
err = p.c.Conn.WritePacket(pk.Marshal( // PluginMessage packet
int32(packetid.ServerboundCustomPayload),
packetid.ServerboundCustomPayload,
pk.Identifier("minecraft:brand"),
pk.String(p.Settings.Brand),
))
@ -63,7 +63,7 @@ func (p *Player) handleLoginPacket(packet pk.Packet) error {
}
err = p.c.Conn.WritePacket(pk.Marshal(
int32(packetid.ServerboundClientInformation), // Client settings
packetid.ServerboundClientInformation, // Client settings
pk.String(p.Settings.Locale),
pk.Byte(p.Settings.ViewDistance),
pk.VarInt(p.Settings.ChatMode),

View File

@ -1,15 +1,19 @@
package bot
import (
"errors"
"github.com/google/uuid"
"github.com/Tnze/go-mc/data/packetid"
"github.com/Tnze/go-mc/net"
pk "github.com/Tnze/go-mc/net/packet"
"github.com/Tnze/go-mc/net/queue"
)
// Client is used to access Minecraft server
type Client struct {
Conn *net.Conn
Conn *Conn
Auth Auth
Name string
@ -37,6 +41,71 @@ func NewClient() *Client {
}
}
// Conn is a concurrently-safe warpper of net.Conn with packet queue.
// Note that not all methods are concurrently-safe.
type Conn struct {
conn *net.Conn
send, recv queue.Queue[pk.Packet]
rerr error
}
func warpConn(c *net.Conn) *Conn {
wc := Conn{
conn: c,
send: make(queue.ChannelQueue[pk.Packet], 256),
recv: make(queue.ChannelQueue[pk.Packet], 256),
rerr: nil,
}
go func() {
for {
var p pk.Packet
if err := c.ReadPacket(&p); err != nil {
wc.rerr = err
break
}
if ok := wc.recv.Push(p); !ok {
break
}
}
wc.recv.Close()
}()
go func() {
for {
p, ok := wc.send.Pull()
if !ok {
break
}
if err := c.WritePacket(p); err != nil {
break
}
}
}()
return &wc
}
func (c *Conn) ReadPacket(p *pk.Packet) error {
packet, ok := c.recv.Pull()
if !ok {
return c.rerr
}
*p = packet
return nil
}
func (c *Conn) WritePacket(p pk.Packet) error {
ok := c.send.Push(p)
if !ok {
return errors.New("queue is full")
}
return nil
}
func (c *Conn) Close() error {
c.send.Close()
return c.conn.Close()
}
// Position is a 3D vector.
type Position struct {
X, Y, Z int

View File

@ -15,6 +15,7 @@ import (
"strings"
"github.com/Tnze/go-mc/data/packetid"
"github.com/Tnze/go-mc/net"
"github.com/Tnze/go-mc/net/CFB8"
pk "github.com/Tnze/go-mc/net/packet"
)
@ -26,7 +27,7 @@ type Auth struct {
AsTk string
}
func handleEncryptionRequest(c *Client, p pk.Packet) error {
func handleEncryptionRequest(conn *net.Conn, c *Client, p pk.Packet) error {
// 创建AES对称加密密钥
key, encoStream, decoStream := newSymmetricEncryption()
@ -48,13 +49,13 @@ func handleEncryptionRequest(c *Client, p pk.Packet) error {
return fmt.Errorf("gen encryption key response fail: %v", err)
}
err = c.Conn.WritePacket(p)
err = conn.WritePacket(p)
if err != nil {
return err
}
// 设置连接加密
c.Conn.SetCipher(encoStream, decoStream)
conn.SetCipher(encoStream, decoStream)
return nil
}

View File

@ -88,13 +88,13 @@ func (c *Client) join(addr string, options JoinOptions) error {
}
// Dial connection
c.Conn, err = options.MCDialer.DialMCContext(options.Context, addr)
conn, err := options.MCDialer.DialMCContext(options.Context, addr)
if err != nil {
return LoginErr{"connect server", err}
}
// Handshake
err = c.Conn.WritePacket(pk.Marshal(
err = conn.WritePacket(pk.Marshal(
Handshake,
pk.VarInt(ProtocolVersion), // Protocol version
pk.String(host), // Host
@ -110,7 +110,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
Has: err == nil,
Val: pk.UUID(c.UUID),
}
err = c.Conn.WritePacket(pk.Marshal(
err = conn.WritePacket(pk.Marshal(
packetid.LoginStart,
pk.String(c.Auth.Name),
PlayerUUID,
@ -122,7 +122,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
for {
// Receive Packet
var p pk.Packet
if err = c.Conn.ReadPacket(&p); err != nil {
if err = conn.ReadPacket(&p); err != nil {
return LoginErr{receiving, err}
}
@ -137,7 +137,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
return LoginErr{"disconnect", DisconnectErr(reason)}
case packetid.LoginEncryptionRequest: // Encryption Request
if err := handleEncryptionRequest(c, p); err != nil {
if err := handleEncryptionRequest(conn, c, p); err != nil {
return LoginErr{"encryption", err}
}
receiving = "set compression"
@ -150,6 +150,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
if err != nil {
return LoginErr{"login success", err}
}
c.Conn = warpConn(conn)
return nil
case packetid.LoginCompression: // Set Compression
@ -157,7 +158,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
if err := p.Scan(&threshold); err != nil {
return LoginErr{"compression", err}
}
c.Conn.SetThreshold(int(threshold))
conn.SetThreshold(int(threshold))
receiving = "login success"
case packetid.LoginPluginRequest: // Login Plugin Request
@ -179,7 +180,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
}
}
if err := c.Conn.WritePacket(pk.Marshal(
if err := conn.WritePacket(pk.Marshal(
packetid.LoginPluginResponse,
msgid, PluginMessageData,
)); err != nil {