From cf25807f68481e8029cb2a882eddfa9c0c320150 Mon Sep 17 00:00:00 2001 From: Tnze Date: Sun, 6 Mar 2022 00:28:46 +0800 Subject: [PATCH] Support SRV records in net.Dial and bot.PingAndListContext --- bot/mcbot.go | 44 ++--------- bot/pinglist.go | 75 ++++++++++-------- examples/mcping/README.md | 7 +- examples/mcping/mcping.go | 91 +++++++++++----------- net/conn.go | 156 ++++++++++++++++++++++++++++++++++---- 5 files changed, 240 insertions(+), 133 deletions(-) diff --git a/bot/mcbot.go b/bot/mcbot.go index 8a001b3..8315dfc 100644 --- a/bot/mcbot.go +++ b/bot/mcbot.go @@ -6,7 +6,6 @@ package bot import ( "context" - "errors" "net" "strconv" @@ -18,55 +17,24 @@ import ( // ProtocolVersion is the protocol version number of minecraft net protocol const ProtocolVersion = 757 -const DefaultPort = 25565 +const DefaultPort = mcnet.DefaultPort // JoinServer connect a Minecraft server for playing the game. // Using roughly the same way to parse address as minecraft. func (c *Client) JoinServer(addr string) (err error) { - return c.join(&net.Dialer{}, addr) + return c.join(context.Background(), &mcnet.DefaultDialer, addr) } // JoinServerWithDialer is similar to JoinServer but using a Dialer. func (c *Client) JoinServerWithDialer(d *net.Dialer, addr string) (err error) { - return c.join(d, addr) + return c.join(context.Background(), &mcnet.Dialer{Dialer: d}, addr) } -// parseAddress will look up SRV records for the address -func parseAddress(r *net.Resolver, addr string) (string, error) { - var port uint16 - var addrErr *net.AddrError - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - if errors.As(err, &addrErr) { - host, port = addr, DefaultPort - } else { - return "", err - } - } else { - if portInt, err := strconv.ParseUint(portStr, 10, 16); err != nil { - port = DefaultPort - } else { - port = uint16(portInt) - } - } - - _, srvs, err := r.LookupSRV(context.TODO(), "minecraft", "tcp", host) - if err == nil && len(srvs) > 0 { - host, port = srvs[0].Target, srvs[0].Port - } - - return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil -} - -func (c *Client) join(d *net.Dialer, addr string) error { +func (c *Client) join(ctx context.Context, d *mcnet.Dialer, addr string) error { const Handshake = 0x00 - addrSrv, err := parseAddress(d.Resolver, addr) - if err != nil { - return LoginErr{"resolved address", err} - } // Split Host and Port - host, portStr, err := net.SplitHostPort(addrSrv) + host, portStr, err := net.SplitHostPort(addr) if err != nil { return LoginErr{"split address", err} } @@ -76,7 +44,7 @@ func (c *Client) join(d *net.Dialer, addr string) error { } // Dial connection - c.Conn, err = mcnet.DialMC(addrSrv) + c.Conn, err = d.DialMCContext(ctx, addr) if err != nil { return LoginErr{"connect server", err} } diff --git a/bot/pinglist.go b/bot/pinglist.go index e613397..864ccc7 100644 --- a/bot/pinglist.go +++ b/bot/pinglist.go @@ -1,8 +1,11 @@ package bot import ( + "context" + "errors" "fmt" "net" + "os" "strconv" "time" @@ -16,53 +19,63 @@ import ( // // For more information for JSON format, see https://wiki.vg/Server_List_Ping#Response func PingAndList(addr string) ([]byte, time.Duration, error) { - addrSrv, err := parseAddress(&net.Resolver{}, addr) - if err != nil { - return nil, 0, LoginErr{"parse address", err} - } - - conn, err := mcnet.DialMC(addrSrv) + conn, err := mcnet.DialMC(addr) if err != nil { return nil, 0, LoginErr{"dial connection", err} } - return pingAndList(addr, conn) + return pingAndList(context.Background(), addr, conn) } // PingAndListTimeout is the version of PingAndList with max request time. func PingAndListTimeout(addr string, timeout time.Duration) ([]byte, time.Duration, error) { - deadLine := time.Now().Add(timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return PingAndListContext(ctx, addr) +} - addrSrv, err := parseAddress(&net.Resolver{}, addr) - if err != nil { - return nil, 0, LoginErr{"parse address", err} - } - - conn, err := mcnet.DialMCTimeout(addrSrv, timeout) +func PingAndListContext(ctx context.Context, addr string) ([]byte, time.Duration, error) { + conn, err := mcnet.DefaultDialer.DialMCContext(ctx, addr) if err != nil { return nil, 0, err } - - err = conn.Socket.SetDeadline(deadLine) - if err != nil { - return nil, 0, LoginErr{"set deadline", err} - } - - return pingAndList(addr, conn) + return pingAndList(ctx, addr, conn) } -func pingAndList(addr string, conn *mcnet.Conn) ([]byte, time.Duration, error) { - addrSrv, err := parseAddress(nil, addr) - if err != nil { - return nil, 0, LoginErr{"resolved address", err} +func pingAndList(ctx context.Context, addr string, conn *mcnet.Conn) (data []byte, delay time.Duration, err error) { + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + if err := conn.Socket.SetDeadline(deadline); err != nil { + return nil, 0, err + } + defer func() { + // Reset deadline + if err2 := conn.Socket.SetDeadline(time.Time{}); err2 != nil { + if err2 == nil { + err = err2 + } + return + } + // Map error type + if errors.Is(err, os.ErrDeadlineExceeded) { + err = context.DeadlineExceeded + } + }() } // Split Host and Port - host, portStr, err := net.SplitHostPort(addrSrv) + host, portStr, err := net.SplitHostPort(addr) + var port uint64 if err != nil { - return nil, 0, LoginErr{"split address", err} - } - port, err := strconv.ParseUint(portStr, 0, 16) - if err != nil { - return nil, 0, LoginErr{"parse port", err} + var addrErr *net.AddrError + const missingPort = "missing port in address" + if errors.As(err, &addrErr) && addrErr.Err == missingPort { + host, port, err = addr, DefaultPort, nil + } else { + return nil, 0, LoginErr{"split address", err} + } + } else { + port, err = strconv.ParseUint(portStr, 0, 16) + if err != nil { + return nil, 0, LoginErr{"parse port", err} + } } const Handshake = 0x00 diff --git a/examples/mcping/README.md b/examples/mcping/README.md index c087b52..22222d6 100644 --- a/examples/mcping/README.md +++ b/examples/mcping/README.md @@ -1,10 +1,7 @@ # mcping -Ping tool for Minecraft: Java Edition. -Just for example. Not recommended for daily use. Use [github.com/go-mc/mcping](github.com/go-mc/mcping) instead, which including SRV parse. - +A ping tool for Minecraft: Java Edition. 适用于Minecraft: Java Edition的ping工具。 -只起示例作用,日常使用建议使用完整版[github.com/go-mc/mcping](github.com/go-mc/mcping),包含SRV解析等功能。 Install with go tools: ```go get -u github.com/Tnze/go-mc/cmd/mcping``` @@ -13,5 +10,5 @@ Install with go tools: Install with Homebrew: ```brew tap Tnze/tap && brew install mcping``` -Useage: +Usage: ```mcping [:port]``` diff --git a/examples/mcping/mcping.go b/examples/mcping/mcping.go index 65bca14..53bbf11 100644 --- a/examples/mcping/mcping.go +++ b/examples/mcping/mcping.go @@ -4,18 +4,25 @@ package main import ( "encoding/base64" "encoding/json" - "errors" + "flag" "fmt" + "image" + _ "image/jpeg" + "image/png" "os" "strings" "text/template" "time" + "github.com/google/uuid" + "github.com/Tnze/go-mc/bot" "github.com/Tnze/go-mc/chat" - "github.com/google/uuid" ) +var protocol = flag.Int("p", 578, "The protocol version number sent during ping") +var favicon = flag.String("f", "", "If specified, the server's icon will be save to") + type status struct { Description chat.Message Players struct { @@ -39,50 +46,15 @@ type status struct { // and prepended with "data:image/png;base64,". type Icon string -var IconFormatErr = errors.New("data format error") -var IconAbsentErr = errors.New("icon not present") - -// ToPNG decode base64-icon, return a PNG image -// Take care of there is no safety check, image may contain malicious code. -func (i Icon) ToPNG() ([]byte, error) { +func (i Icon) ToImage() (icon image.Image, err error) { const prefix = "data:image/png;base64," - if i == "" { - return nil, IconAbsentErr - } if !strings.HasPrefix(string(i), prefix) { - return nil, IconFormatErr + return nil, fmt.Errorf("server icon should prepended with %q", prefix) } - return base64.StdEncoding.DecodeString(strings.TrimPrefix(string(i), prefix)) -} - -func main() { - addr := getAddr() - fmt.Printf("MCPING (%s):", addr) - resp, delay, err := bot.PingAndList(addr) - if err != nil { - fmt.Printf("ping and list server fail: %v", err) - os.Exit(1) - } - - var s status - err = json.Unmarshal(resp, &s) - if err != nil { - fmt.Print("unmarshal resp fail:", err) - os.Exit(1) - } - s.Delay = delay - - fmt.Print(&s) -} - -func getAddr() string { - const usage = "Usage: mcping [:port]" - if len(os.Args) < 2 { - fmt.Println("no host name.", usage) - os.Exit(1) - } - - return os.Args[1] + base64png := strings.TrimPrefix(string(i), prefix) + r := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64png)) + icon, err = png.Decode(r) + return } var outTemp = template.Must(template.New("output").Parse(` @@ -102,3 +74,36 @@ func (s *status) String() string { } return sb.String() } + +func usage() { + _, _ = fmt.Fprintf(flag.CommandLine.Output(), "Usage:\n%s [-f] [-p]
[:port]\n", os.Args[0]) + flag.PrintDefaults() +} + +func main() { + flag.Parse() + flag.Usage = usage + addr := flag.Arg(0) + if addr == "" { + fmt.Println("") + flag.Usage() + os.Exit(2) + } + + fmt.Printf("MCPING (%s):", addr) + resp, delay, err := bot.PingAndList(addr) + if err != nil { + fmt.Printf("Ping and list server fail: %v", err) + os.Exit(1) + } + + var s status + err = json.Unmarshal(resp, &s) + if err != nil { + fmt.Print("Parse json response fail:", err) + os.Exit(1) + } + s.Delay = delay + + fmt.Print(&s) +} diff --git a/net/conn.go b/net/conn.go index b7d3b28..88e05d1 100644 --- a/net/conn.go +++ b/net/conn.go @@ -2,14 +2,19 @@ package net import ( + "context" "crypto/cipher" + "errors" "io" "net" + "strconv" "time" pk "github.com/Tnze/go-mc/net/packet" ) +const DefaultPort = 25565 + // A Listener is a minecraft Listener type Listener struct{ net.Listener } @@ -42,30 +47,149 @@ type Conn struct { threshold int } +var DefaultDialer = Dialer{} + // DialMC create a Minecraft connection +// Lookup SRV records only if port doesn't exist or equals to 0. func DialMC(addr string) (*Conn, error) { - conn, err := net.Dial("tcp", addr) - return &Conn{ - Socket: conn, - Reader: conn, - Writer: conn, - threshold: -1, - }, err + return DefaultDialer.DialMCContext(context.Background(), addr) } // DialMCTimeout acts like DialMC but takes a timeout. func DialMCTimeout(addr string, timeout time.Duration) (*Conn, error) { - conn, err := net.DialTimeout("tcp", addr, timeout) - return &Conn{ - Socket: conn, - Reader: conn, - Writer: conn, - threshold: -1, - }, err + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return DefaultDialer.DialMCContext(ctx, addr) } -// WrapConn warp an net.Conn to MC-Conn -// Helps you modify the connection process (eg. using DialContext). +type Dialer struct { + *net.Dialer +} + +func (d *Dialer) resolver() *net.Resolver { + if d.Resolver != nil { + return d.Resolver + } + return net.DefaultResolver +} + +func (d *Dialer) DialMCContext(ctx context.Context, addr string) (*Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + var addrErr *net.AddrError + const missingPort = "missing port in address" + if errors.As(err, &addrErr) && addrErr.Err == missingPort { + host, port, err = addr, "", nil + } else { + return nil, err + } + } + var ras []string + if port == "" { + // We look up SRV only if the port is not specified + _, srvRecords, err := d.resolver().LookupSRV(ctx, "minecraft", "tcp", host) + if err == nil { + for _, record := range srvRecords { + addr := net.JoinHostPort(record.Target, strconv.Itoa(int(record.Port))) + ras = append(ras, addr) + } + } + // Whatever the SRV records is found, + addr = net.JoinHostPort(addr, strconv.Itoa(DefaultPort)) + } + ras = append(ras, addr) + + var firstErr error + for i, addr := range ras { + select { + case <-ctx.Done(): + return nil, context.Canceled + default: + } + dialCtx := ctx + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i) + if err != nil { + // Ran out of time. + if firstErr == nil { + firstErr = context.DeadlineExceeded + } + break + } + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() + } + } + conn, err := d.DialContext(dialCtx, "tcp", addr) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + return WrapConn(conn), nil + } + return nil, firstErr +} + +// deadline returns the earliest of: +// - now+Timeout +// - d.Deadline +// - the context's deadline +// Or zero, if none of Timeout, Deadline, or context's deadline is set. +// +// Copied from net/dial.go +func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) { + if d.Timeout != 0 { // including negative, for historical reasons + earliest = now.Add(d.Timeout) + } + if d, ok := ctx.Deadline(); ok { + earliest = minNonzeroTime(earliest, d) + } + return minNonzeroTime(earliest, d.Deadline) +} + +// Copied from net/dial.go +func minNonzeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() || a.Before(b) { + return a + } + return b +} + +// partialDeadline returns the deadline to use for a single address, +// when multiple addresses are pending. +// +// Copied from net/dial.go +func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { + if deadline.IsZero() { + return deadline, nil + } + timeRemaining := deadline.Sub(now) + if timeRemaining <= 0 { + return time.Time{}, context.DeadlineExceeded + } + // Tentatively allocate equal time to each remaining address. + timeout := timeRemaining / time.Duration(addrsRemaining) + // If the time per address is too short, steal from the end of the list. + const saneMinimum = 2 * time.Second + if timeout < saneMinimum { + if timeRemaining < saneMinimum { + timeout = timeRemaining + } else { + timeout = saneMinimum + } + } + return now.Add(timeout), nil +} + +// WrapConn warp a net.Conn to MC-Conn +// Helps you modify the connection process (e.g. using DialContext). func WrapConn(conn net.Conn) *Conn { return &Conn{ Socket: conn,