Use packet queue for bot framework.

Providing concurrently-safing for sending packets.
This commit is contained in:
Tnze
2023-03-05 10:19:18 +08:00
parent 90501f1357
commit 925b1359fc
5 changed files with 85 additions and 14 deletions

View File

@ -51,7 +51,7 @@ func (p *Player) Respawn() error {
const PerformRespawn = 0 const PerformRespawn = 0
err := p.c.Conn.WritePacket(pk.Marshal( err := p.c.Conn.WritePacket(pk.Marshal(
int32(packetid.ServerboundClientCommand), packetid.ServerboundClientCommand,
pk.VarInt(PerformRespawn), pk.VarInt(PerformRespawn),
)) ))
if err != nil { if err != nil {

View File

@ -54,7 +54,7 @@ func (p *Player) handleLoginPacket(packet pk.Packet) error {
return Error{err} return Error{err}
} }
err = p.c.Conn.WritePacket(pk.Marshal( // PluginMessage packet err = p.c.Conn.WritePacket(pk.Marshal( // PluginMessage packet
int32(packetid.ServerboundCustomPayload), packetid.ServerboundCustomPayload,
pk.Identifier("minecraft:brand"), pk.Identifier("minecraft:brand"),
pk.String(p.Settings.Brand), pk.String(p.Settings.Brand),
)) ))
@ -63,7 +63,7 @@ func (p *Player) handleLoginPacket(packet pk.Packet) error {
} }
err = p.c.Conn.WritePacket(pk.Marshal( err = p.c.Conn.WritePacket(pk.Marshal(
int32(packetid.ServerboundClientInformation), // Client settings packetid.ServerboundClientInformation, // Client settings
pk.String(p.Settings.Locale), pk.String(p.Settings.Locale),
pk.Byte(p.Settings.ViewDistance), pk.Byte(p.Settings.ViewDistance),
pk.VarInt(p.Settings.ChatMode), pk.VarInt(p.Settings.ChatMode),

View File

@ -1,15 +1,19 @@
package bot package bot
import ( import (
"errors"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/Tnze/go-mc/data/packetid" "github.com/Tnze/go-mc/data/packetid"
"github.com/Tnze/go-mc/net" "github.com/Tnze/go-mc/net"
pk "github.com/Tnze/go-mc/net/packet"
"github.com/Tnze/go-mc/net/queue"
) )
// Client is used to access Minecraft server // Client is used to access Minecraft server
type Client struct { type Client struct {
Conn *net.Conn Conn *Conn
Auth Auth Auth Auth
Name string Name string
@ -37,6 +41,71 @@ func NewClient() *Client {
} }
} }
// Conn is a concurrently-safe warpper of net.Conn with packet queue.
// Note that not all methods are concurrently-safe.
type Conn struct {
conn *net.Conn
send, recv queue.Queue[pk.Packet]
rerr error
}
func warpConn(c *net.Conn) *Conn {
wc := Conn{
conn: c,
send: make(queue.ChannelQueue[pk.Packet], 256),
recv: make(queue.ChannelQueue[pk.Packet], 256),
rerr: nil,
}
go func() {
for {
var p pk.Packet
if err := c.ReadPacket(&p); err != nil {
wc.rerr = err
break
}
if ok := wc.recv.Push(p); !ok {
break
}
}
wc.recv.Close()
}()
go func() {
for {
p, ok := wc.send.Pull()
if !ok {
break
}
if err := c.WritePacket(p); err != nil {
break
}
}
}()
return &wc
}
func (c *Conn) ReadPacket(p *pk.Packet) error {
packet, ok := c.recv.Pull()
if !ok {
return c.rerr
}
*p = packet
return nil
}
func (c *Conn) WritePacket(p pk.Packet) error {
ok := c.send.Push(p)
if !ok {
return errors.New("queue is full")
}
return nil
}
func (c *Conn) Close() error {
c.send.Close()
return c.conn.Close()
}
// Position is a 3D vector. // Position is a 3D vector.
type Position struct { type Position struct {
X, Y, Z int X, Y, Z int

View File

@ -15,6 +15,7 @@ import (
"strings" "strings"
"github.com/Tnze/go-mc/data/packetid" "github.com/Tnze/go-mc/data/packetid"
"github.com/Tnze/go-mc/net"
"github.com/Tnze/go-mc/net/CFB8" "github.com/Tnze/go-mc/net/CFB8"
pk "github.com/Tnze/go-mc/net/packet" pk "github.com/Tnze/go-mc/net/packet"
) )
@ -26,7 +27,7 @@ type Auth struct {
AsTk string AsTk string
} }
func handleEncryptionRequest(c *Client, p pk.Packet) error { func handleEncryptionRequest(conn *net.Conn, c *Client, p pk.Packet) error {
// 创建AES对称加密密钥 // 创建AES对称加密密钥
key, encoStream, decoStream := newSymmetricEncryption() key, encoStream, decoStream := newSymmetricEncryption()
@ -48,13 +49,13 @@ func handleEncryptionRequest(c *Client, p pk.Packet) error {
return fmt.Errorf("gen encryption key response fail: %v", err) return fmt.Errorf("gen encryption key response fail: %v", err)
} }
err = c.Conn.WritePacket(p) err = conn.WritePacket(p)
if err != nil { if err != nil {
return err return err
} }
// 设置连接加密 // 设置连接加密
c.Conn.SetCipher(encoStream, decoStream) conn.SetCipher(encoStream, decoStream)
return nil return nil
} }

View File

@ -88,13 +88,13 @@ func (c *Client) join(addr string, options JoinOptions) error {
} }
// Dial connection // Dial connection
c.Conn, err = options.MCDialer.DialMCContext(options.Context, addr) conn, err := options.MCDialer.DialMCContext(options.Context, addr)
if err != nil { if err != nil {
return LoginErr{"connect server", err} return LoginErr{"connect server", err}
} }
// Handshake // Handshake
err = c.Conn.WritePacket(pk.Marshal( err = conn.WritePacket(pk.Marshal(
Handshake, Handshake,
pk.VarInt(ProtocolVersion), // Protocol version pk.VarInt(ProtocolVersion), // Protocol version
pk.String(host), // Host pk.String(host), // Host
@ -110,7 +110,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
Has: err == nil, Has: err == nil,
Val: pk.UUID(c.UUID), Val: pk.UUID(c.UUID),
} }
err = c.Conn.WritePacket(pk.Marshal( err = conn.WritePacket(pk.Marshal(
packetid.LoginStart, packetid.LoginStart,
pk.String(c.Auth.Name), pk.String(c.Auth.Name),
PlayerUUID, PlayerUUID,
@ -122,7 +122,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
for { for {
// Receive Packet // Receive Packet
var p pk.Packet var p pk.Packet
if err = c.Conn.ReadPacket(&p); err != nil { if err = conn.ReadPacket(&p); err != nil {
return LoginErr{receiving, err} return LoginErr{receiving, err}
} }
@ -137,7 +137,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
return LoginErr{"disconnect", DisconnectErr(reason)} return LoginErr{"disconnect", DisconnectErr(reason)}
case packetid.LoginEncryptionRequest: // Encryption Request case packetid.LoginEncryptionRequest: // Encryption Request
if err := handleEncryptionRequest(c, p); err != nil { if err := handleEncryptionRequest(conn, c, p); err != nil {
return LoginErr{"encryption", err} return LoginErr{"encryption", err}
} }
receiving = "set compression" receiving = "set compression"
@ -150,6 +150,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
if err != nil { if err != nil {
return LoginErr{"login success", err} return LoginErr{"login success", err}
} }
c.Conn = warpConn(conn)
return nil return nil
case packetid.LoginCompression: // Set Compression case packetid.LoginCompression: // Set Compression
@ -157,7 +158,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
if err := p.Scan(&threshold); err != nil { if err := p.Scan(&threshold); err != nil {
return LoginErr{"compression", err} return LoginErr{"compression", err}
} }
c.Conn.SetThreshold(int(threshold)) conn.SetThreshold(int(threshold))
receiving = "login success" receiving = "login success"
case packetid.LoginPluginRequest: // Login Plugin Request case packetid.LoginPluginRequest: // Login Plugin Request
@ -179,7 +180,7 @@ func (c *Client) join(addr string, options JoinOptions) error {
} }
} }
if err := c.Conn.WritePacket(pk.Marshal( if err := conn.WritePacket(pk.Marshal(
packetid.LoginPluginResponse, packetid.LoginPluginResponse,
msgid, PluginMessageData, msgid, PluginMessageData,
)); err != nil { )); err != nil {