Support SRV records in net.Dial and bot.PingAndListContext

This commit is contained in:
Tnze
2022-03-06 00:28:46 +08:00
parent 23bcf9149a
commit cf25807f68
5 changed files with 240 additions and 133 deletions

View File

@ -1,8 +1,11 @@
package bot
import (
"context"
"errors"
"fmt"
"net"
"os"
"strconv"
"time"
@ -16,53 +19,63 @@ import (
//
// For more information for JSON format, see https://wiki.vg/Server_List_Ping#Response
func PingAndList(addr string) ([]byte, time.Duration, error) {
addrSrv, err := parseAddress(&net.Resolver{}, addr)
if err != nil {
return nil, 0, LoginErr{"parse address", err}
}
conn, err := mcnet.DialMC(addrSrv)
conn, err := mcnet.DialMC(addr)
if err != nil {
return nil, 0, LoginErr{"dial connection", err}
}
return pingAndList(addr, conn)
return pingAndList(context.Background(), addr, conn)
}
// PingAndListTimeout is the version of PingAndList with max request time.
func PingAndListTimeout(addr string, timeout time.Duration) ([]byte, time.Duration, error) {
deadLine := time.Now().Add(timeout)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return PingAndListContext(ctx, addr)
}
addrSrv, err := parseAddress(&net.Resolver{}, addr)
if err != nil {
return nil, 0, LoginErr{"parse address", err}
}
conn, err := mcnet.DialMCTimeout(addrSrv, timeout)
func PingAndListContext(ctx context.Context, addr string) ([]byte, time.Duration, error) {
conn, err := mcnet.DefaultDialer.DialMCContext(ctx, addr)
if err != nil {
return nil, 0, err
}
err = conn.Socket.SetDeadline(deadLine)
if err != nil {
return nil, 0, LoginErr{"set deadline", err}
}
return pingAndList(addr, conn)
return pingAndList(ctx, addr, conn)
}
func pingAndList(addr string, conn *mcnet.Conn) ([]byte, time.Duration, error) {
addrSrv, err := parseAddress(nil, addr)
if err != nil {
return nil, 0, LoginErr{"resolved address", err}
func pingAndList(ctx context.Context, addr string, conn *mcnet.Conn) (data []byte, delay time.Duration, err error) {
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
if err := conn.Socket.SetDeadline(deadline); err != nil {
return nil, 0, err
}
defer func() {
// Reset deadline
if err2 := conn.Socket.SetDeadline(time.Time{}); err2 != nil {
if err2 == nil {
err = err2
}
return
}
// Map error type
if errors.Is(err, os.ErrDeadlineExceeded) {
err = context.DeadlineExceeded
}
}()
}
// Split Host and Port
host, portStr, err := net.SplitHostPort(addrSrv)
host, portStr, err := net.SplitHostPort(addr)
var port uint64
if err != nil {
return nil, 0, LoginErr{"split address", err}
}
port, err := strconv.ParseUint(portStr, 0, 16)
if err != nil {
return nil, 0, LoginErr{"parse port", err}
var addrErr *net.AddrError
const missingPort = "missing port in address"
if errors.As(err, &addrErr) && addrErr.Err == missingPort {
host, port, err = addr, DefaultPort, nil
} else {
return nil, 0, LoginErr{"split address", err}
}
} else {
port, err = strconv.ParseUint(portStr, 0, 16)
if err != nil {
return nil, 0, LoginErr{"parse port", err}
}
}
const Handshake = 0x00