diff --git a/bot/basic/basic.go b/bot/basic/basic.go index 94a1364..a923015 100644 --- a/bot/basic/basic.go +++ b/bot/basic/basic.go @@ -51,7 +51,7 @@ func (p *Player) Respawn() error { const PerformRespawn = 0 err := p.c.Conn.WritePacket(pk.Marshal( - int32(packetid.ServerboundClientCommand), + packetid.ServerboundClientCommand, pk.VarInt(PerformRespawn), )) if err != nil { diff --git a/bot/basic/info.go b/bot/basic/info.go index fe6460a..63e888a 100644 --- a/bot/basic/info.go +++ b/bot/basic/info.go @@ -54,7 +54,7 @@ func (p *Player) handleLoginPacket(packet pk.Packet) error { return Error{err} } err = p.c.Conn.WritePacket(pk.Marshal( // PluginMessage packet - int32(packetid.ServerboundCustomPayload), + packetid.ServerboundCustomPayload, pk.Identifier("minecraft:brand"), pk.String(p.Settings.Brand), )) @@ -63,7 +63,7 @@ func (p *Player) handleLoginPacket(packet pk.Packet) error { } err = p.c.Conn.WritePacket(pk.Marshal( - int32(packetid.ServerboundClientInformation), // Client settings + packetid.ServerboundClientInformation, // Client settings pk.String(p.Settings.Locale), pk.Byte(p.Settings.ViewDistance), pk.VarInt(p.Settings.ChatMode), diff --git a/bot/client.go b/bot/client.go index 46f2bc2..0b23b62 100644 --- a/bot/client.go +++ b/bot/client.go @@ -1,15 +1,19 @@ package bot import ( + "errors" + "github.com/google/uuid" "github.com/Tnze/go-mc/data/packetid" "github.com/Tnze/go-mc/net" + pk "github.com/Tnze/go-mc/net/packet" + "github.com/Tnze/go-mc/net/queue" ) // Client is used to access Minecraft server type Client struct { - Conn *net.Conn + Conn *Conn Auth Auth Name string @@ -37,6 +41,71 @@ func NewClient() *Client { } } +// Conn is a concurrently-safe warpper of net.Conn with packet queue. +// Note that not all methods are concurrently-safe. +type Conn struct { + conn *net.Conn + send, recv queue.Queue[pk.Packet] + rerr error +} + +func warpConn(c *net.Conn) *Conn { + wc := Conn{ + conn: c, + send: make(queue.ChannelQueue[pk.Packet], 256), + recv: make(queue.ChannelQueue[pk.Packet], 256), + rerr: nil, + } + go func() { + for { + var p pk.Packet + if err := c.ReadPacket(&p); err != nil { + wc.rerr = err + break + } + if ok := wc.recv.Push(p); !ok { + break + } + } + wc.recv.Close() + }() + go func() { + for { + p, ok := wc.send.Pull() + if !ok { + break + } + if err := c.WritePacket(p); err != nil { + break + } + } + }() + + return &wc +} + +func (c *Conn) ReadPacket(p *pk.Packet) error { + packet, ok := c.recv.Pull() + if !ok { + return c.rerr + } + *p = packet + return nil +} + +func (c *Conn) WritePacket(p pk.Packet) error { + ok := c.send.Push(p) + if !ok { + return errors.New("queue is full") + } + return nil +} + +func (c *Conn) Close() error { + c.send.Close() + return c.conn.Close() +} + // Position is a 3D vector. type Position struct { X, Y, Z int diff --git a/bot/login.go b/bot/login.go index 2b6e9e4..be3642b 100644 --- a/bot/login.go +++ b/bot/login.go @@ -15,6 +15,7 @@ import ( "strings" "github.com/Tnze/go-mc/data/packetid" + "github.com/Tnze/go-mc/net" "github.com/Tnze/go-mc/net/CFB8" pk "github.com/Tnze/go-mc/net/packet" ) @@ -26,7 +27,7 @@ type Auth struct { AsTk string } -func handleEncryptionRequest(c *Client, p pk.Packet) error { +func handleEncryptionRequest(conn *net.Conn, c *Client, p pk.Packet) error { // 创建AES对称加密密钥 key, encoStream, decoStream := newSymmetricEncryption() @@ -48,13 +49,13 @@ func handleEncryptionRequest(c *Client, p pk.Packet) error { return fmt.Errorf("gen encryption key response fail: %v", err) } - err = c.Conn.WritePacket(p) + err = conn.WritePacket(p) if err != nil { return err } // 设置连接加密 - c.Conn.SetCipher(encoStream, decoStream) + conn.SetCipher(encoStream, decoStream) return nil } diff --git a/bot/mcbot.go b/bot/mcbot.go index 86a2a59..7dc3dcc 100644 --- a/bot/mcbot.go +++ b/bot/mcbot.go @@ -88,13 +88,13 @@ func (c *Client) join(addr string, options JoinOptions) error { } // Dial connection - c.Conn, err = options.MCDialer.DialMCContext(options.Context, addr) + conn, err := options.MCDialer.DialMCContext(options.Context, addr) if err != nil { return LoginErr{"connect server", err} } // Handshake - err = c.Conn.WritePacket(pk.Marshal( + err = conn.WritePacket(pk.Marshal( Handshake, pk.VarInt(ProtocolVersion), // Protocol version pk.String(host), // Host @@ -110,7 +110,7 @@ func (c *Client) join(addr string, options JoinOptions) error { Has: err == nil, Val: pk.UUID(c.UUID), } - err = c.Conn.WritePacket(pk.Marshal( + err = conn.WritePacket(pk.Marshal( packetid.LoginStart, pk.String(c.Auth.Name), PlayerUUID, @@ -122,7 +122,7 @@ func (c *Client) join(addr string, options JoinOptions) error { for { // Receive Packet var p pk.Packet - if err = c.Conn.ReadPacket(&p); err != nil { + if err = conn.ReadPacket(&p); err != nil { return LoginErr{receiving, err} } @@ -137,7 +137,7 @@ func (c *Client) join(addr string, options JoinOptions) error { return LoginErr{"disconnect", DisconnectErr(reason)} case packetid.LoginEncryptionRequest: // Encryption Request - if err := handleEncryptionRequest(c, p); err != nil { + if err := handleEncryptionRequest(conn, c, p); err != nil { return LoginErr{"encryption", err} } receiving = "set compression" @@ -150,6 +150,7 @@ func (c *Client) join(addr string, options JoinOptions) error { if err != nil { return LoginErr{"login success", err} } + c.Conn = warpConn(conn) return nil case packetid.LoginCompression: // Set Compression @@ -157,7 +158,7 @@ func (c *Client) join(addr string, options JoinOptions) error { if err := p.Scan(&threshold); err != nil { return LoginErr{"compression", err} } - c.Conn.SetThreshold(int(threshold)) + conn.SetThreshold(int(threshold)) receiving = "login success" case packetid.LoginPluginRequest: // Login Plugin Request @@ -179,7 +180,7 @@ func (c *Client) join(addr string, options JoinOptions) error { } } - if err := c.Conn.WritePacket(pk.Marshal( + if err := conn.WritePacket(pk.Marshal( packetid.LoginPluginResponse, msgid, PluginMessageData, )); err != nil {