Support SRV records in net.Dial and bot.PingAndListContext

This commit is contained in:
Tnze
2022-03-06 00:28:46 +08:00
parent 23bcf9149a
commit cf25807f68
5 changed files with 240 additions and 133 deletions

View File

@ -2,14 +2,19 @@
package net
import (
"context"
"crypto/cipher"
"errors"
"io"
"net"
"strconv"
"time"
pk "github.com/Tnze/go-mc/net/packet"
)
const DefaultPort = 25565
// A Listener is a minecraft Listener
type Listener struct{ net.Listener }
@ -42,30 +47,149 @@ type Conn struct {
threshold int
}
var DefaultDialer = Dialer{}
// DialMC create a Minecraft connection
// Lookup SRV records only if port doesn't exist or equals to 0.
func DialMC(addr string) (*Conn, error) {
conn, err := net.Dial("tcp", addr)
return &Conn{
Socket: conn,
Reader: conn,
Writer: conn,
threshold: -1,
}, err
return DefaultDialer.DialMCContext(context.Background(), addr)
}
// DialMCTimeout acts like DialMC but takes a timeout.
func DialMCTimeout(addr string, timeout time.Duration) (*Conn, error) {
conn, err := net.DialTimeout("tcp", addr, timeout)
return &Conn{
Socket: conn,
Reader: conn,
Writer: conn,
threshold: -1,
}, err
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return DefaultDialer.DialMCContext(ctx, addr)
}
// WrapConn warp an net.Conn to MC-Conn
// Helps you modify the connection process (eg. using DialContext).
type Dialer struct {
*net.Dialer
}
func (d *Dialer) resolver() *net.Resolver {
if d.Resolver != nil {
return d.Resolver
}
return net.DefaultResolver
}
func (d *Dialer) DialMCContext(ctx context.Context, addr string) (*Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
var addrErr *net.AddrError
const missingPort = "missing port in address"
if errors.As(err, &addrErr) && addrErr.Err == missingPort {
host, port, err = addr, "", nil
} else {
return nil, err
}
}
var ras []string
if port == "" {
// We look up SRV only if the port is not specified
_, srvRecords, err := d.resolver().LookupSRV(ctx, "minecraft", "tcp", host)
if err == nil {
for _, record := range srvRecords {
addr := net.JoinHostPort(record.Target, strconv.Itoa(int(record.Port)))
ras = append(ras, addr)
}
}
// Whatever the SRV records is found,
addr = net.JoinHostPort(addr, strconv.Itoa(DefaultPort))
}
ras = append(ras, addr)
var firstErr error
for i, addr := range ras {
select {
case <-ctx.Done():
return nil, context.Canceled
default:
}
dialCtx := ctx
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
if err != nil {
// Ran out of time.
if firstErr == nil {
firstErr = context.DeadlineExceeded
}
break
}
if partialDeadline.Before(deadline) {
var cancel context.CancelFunc
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
defer cancel()
}
}
conn, err := d.DialContext(dialCtx, "tcp", addr)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
return WrapConn(conn), nil
}
return nil, firstErr
}
// deadline returns the earliest of:
// - now+Timeout
// - d.Deadline
// - the context's deadline
// Or zero, if none of Timeout, Deadline, or context's deadline is set.
//
// Copied from net/dial.go
func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
if d.Timeout != 0 { // including negative, for historical reasons
earliest = now.Add(d.Timeout)
}
if d, ok := ctx.Deadline(); ok {
earliest = minNonzeroTime(earliest, d)
}
return minNonzeroTime(earliest, d.Deadline)
}
// Copied from net/dial.go
func minNonzeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() || a.Before(b) {
return a
}
return b
}
// partialDeadline returns the deadline to use for a single address,
// when multiple addresses are pending.
//
// Copied from net/dial.go
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
if deadline.IsZero() {
return deadline, nil
}
timeRemaining := deadline.Sub(now)
if timeRemaining <= 0 {
return time.Time{}, context.DeadlineExceeded
}
// Tentatively allocate equal time to each remaining address.
timeout := timeRemaining / time.Duration(addrsRemaining)
// If the time per address is too short, steal from the end of the list.
const saneMinimum = 2 * time.Second
if timeout < saneMinimum {
if timeRemaining < saneMinimum {
timeout = timeRemaining
} else {
timeout = saneMinimum
}
}
return now.Add(timeout), nil
}
// WrapConn warp a net.Conn to MC-Conn
// Helps you modify the connection process (e.g. using DialContext).
func WrapConn(conn net.Conn) *Conn {
return &Conn{
Socket: conn,