diff --git a/net/rcon.go b/net/rcon.go index f2a964b..691254f 100644 --- a/net/rcon.go +++ b/net/rcon.go @@ -1,6 +1,7 @@ package net import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -8,46 +9,60 @@ import ( "net" ) -func DialRCON(addr string, password string) (c *RCONConn, err error) { - c = &RCONConn{reqID: rand.Int31()} +func DialRCON(addr string, password string) (client RCONClientConn, err error) { + c := &RCONConn{ReqID: rand.Int31()} + client = c c.Conn, err = net.Dial("tcp", addr) if err != nil { + err = fmt.Errorf("connect fail: %v", err) return } //Login - err = c.WritePacket(3, password) + err = c.WritePacket(c.ReqID, 3, password) if err != nil { + err = fmt.Errorf("login fail: %v", err) return } - t, p, err := c.ReadPacket() + //Login resp + r, _, _, err := c.ReadPacket() if err != nil { + err = fmt.Errorf("read login resp fail: %v", err) return } - fmt.Print(t, p) + + if r == c.ReqID { + err = nil + } else if r == -1 { + err = errors.New("login fail") + } else { + err = errors.New("req id not match") + } return } type RCONConn struct { net.Conn - reqID int32 + ReqID int32 } -func (c *RCONConn) ReadPacket() (Type int32, Payload string, err error) { +func (r *RCONConn) ReadPacket() (RequestID, Type int32, Payload string, err error) { //read packet length var Length int32 - err = binary.Read(c, binary.LittleEndian, &Length) + err = binary.Read(r, binary.LittleEndian, &Length) if err != nil { + err = fmt.Errorf("read packet length fail: %v", err) return } //read packet data buf := make([]byte, Length) - err = binary.Read(c, binary.LittleEndian, &buf) + err = binary.Read(r, binary.LittleEndian, &buf) if err != nil { + err = fmt.Errorf("read packet body fail: %v", err) return } @@ -57,27 +72,129 @@ func (c *RCONConn) ReadPacket() (Type int32, Payload string, err error) { return } - RequestID := int32(binary.LittleEndian.Uint32(buf[:4])) + RequestID = int32(binary.LittleEndian.Uint32(buf[:4])) Type = int32(binary.LittleEndian.Uint32(buf[4:8])) Payload = string(buf[8 : Length-2]) - if RequestID == -1 { - err = errors.New("login fail") - } else if RequestID != c.reqID { - err = errors.New("request ID not match") + return +} + +func (r *RCONConn) WritePacket(RequestID, Type int32, Payload string) error { + buf := new(bytes.Buffer) + for _, v := range []interface{}{ + int32(4 + 4 + len(Payload) + 2), //Length + RequestID, //Request ID + Type, //Type + []byte(Payload), //Payload + []byte{0, 0}, //pad + } { + err := binary.Write(buf, binary.LittleEndian, v) + if err != nil { + return err + } + } + + _, err := r.Write(buf.Bytes()) + return err +} + +func (r *RCONConn) Cmd(cmd string) error { + err := r.WritePacket(r.ReqID, 2, cmd) + return err +} + +func (r *RCONConn) Resp() (resp string, err error) { + var ReqID, Type int32 + ReqID, Type, resp, err = r.ReadPacket() + if err != nil { + return + } + + if ReqID != r.ReqID { + err = errors.New("req ID not match") + } else if Type != 0 { + err = fmt.Errorf("packet type wrong: %d", Type) } return } -func (c *RCONConn) WritePacket(Type int32, Payload string) error { - err := binary.Write(c, binary.LittleEndian, []interface{}{ - int32(4 + 4 + len(Payload) + 2), //Length - c.reqID, //Request ID - Type, //Type - []byte(Payload), //Payload - [2]byte{0, 0}, //pad - }) +func (r *RCONConn) AcceptLogin(password string) error { + R, T, P, err := r.ReadPacket() + if err != nil { + return err + } - return err + r.ReqID = R + + //Check packet type + if T != 3 { + return fmt.Errorf("not a login packet: %d", T) + } + + if P != password { + err = r.WritePacket(-1, 2, "") + if err != nil { + return err + } + return errors.New("password wrong") + } + + err = r.WritePacket(R, 2, "") + if err != nil { + return err + } + + return nil +} + +func (r *RCONConn) AcceptCmd() (string, error) { + R, T, P, err := r.ReadPacket() + if err != nil { + return P, err + } + + r.ReqID = R + + //Check packet type + if T != 2 { + return P, fmt.Errorf("not a command packet: %d", T) + } + + return P, nil +} + +func (r *RCONConn) RespCmd(resp string) error { + return r.WritePacket(r.ReqID, 0, resp) +} + +type RCONClientConn interface { + Cmd(cmd string) error + Resp() (resp string, err error) +} + +type RCONServerConn interface { + AcceptLogin(password string) error + AcceptCmd() (cmd string, err error) + RespCmd(resp string) error +} + +func ListenRCON(addr string) (*RCONListener, error) { + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + return &RCONListener{Listener: l}, nil +} + +type RCONListener struct{ net.Listener } + +func (r *RCONListener) Accept() (RCONServerConn, error) { + conn, err := r.Listener.Accept() + if err != nil { + return nil, err + } + + return &RCONConn{Conn: conn}, nil } diff --git a/net/rcon_test.go b/net/rcon_test.go index 9d9f1a1..a93a0e3 100644 --- a/net/rcon_test.go +++ b/net/rcon_test.go @@ -1 +1,69 @@ package net + +import ( + "fmt" + "testing" +) + +func Test(t *testing.T) { + p := make(chan int, 1) + go server(t, p) + <-p + client(t) +} + +func server(t *testing.T, prepare chan<- int) { + l, err := ListenRCON("localhost:25575") + if err != nil { + t.Fatal(err) + } + prepare <- 1 + + for { + conn, err := l.Accept() + if err != nil { + t.Fatal(err) + } + go func(conn RCONServerConn) { + err := conn.AcceptLogin("RightPassword") + if err != nil { + t.Fatal("password wrong") + } + + for { + cmd, err := conn.AcceptCmd() + if err != nil { + t.Log(err) + return + } + resp := handleCommand(cmd) + err = conn.RespCmd(resp) + if err != nil { + t.Fatal(err) + } + } + }(conn) + } +} + +func handleCommand(cmd string) (resp string) { + return fmt.Sprintf("your command is %q", cmd) +} + +func client(t *testing.T) { + conn, err := DialRCON("localhost:25575", "RightPassword") + if err != nil { + t.Fatal(err) + } + + err = conn.Cmd("TEST COMMAND") + if err != nil { + t.Fatal(err) + } + + resp, err := conn.Resp() + if err != nil { + t.Fatal(err) + } + t.Logf("Server response: %q", resp) +}