From 6fdea28236771ee1d90a6fc959075c79939ad566 Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sun, 8 Dec 2024 13:57:25 +0800 Subject: CAP: Primitive negotiation --- cap.go | 25 +++++++++++++++++++++++++ clients.go | 30 +++++++++++++++++++++--------- cmd_cap.go | 35 ++++++++++++++++++++++++++++++++++- cmd_nick.go | 4 ++-- cmd_user.go | 8 ++++---- main.go | 2 ++ 6 files changed, 88 insertions(+), 16 deletions(-) create mode 100644 cap.go diff --git a/cap.go b/cap.go new file mode 100644 index 0000000..83dfbda --- /dev/null +++ b/cap.go @@ -0,0 +1,25 @@ +package main + +import ( + "strings" +) + +var Caps = map[string]string{ + "sasl": "PLAIN,EXTERNAL", +} + +var capls string + +// Can't be in init() because Caps will be registered with init in the future +// and init()s are executed by filename alphabetical order +func setupCapls() { + capls = "" + for k, v := range Caps { + capls += k + if v != "" { + capls += "=" + v + } + capls += " " + } + capls = strings.TrimSuffix(capls, " ") +} diff --git a/clients.go b/clients.go index 07dc532..2719b08 100644 --- a/clients.go +++ b/clients.go @@ -15,6 +15,7 @@ type Client struct { Ident string Gecos string Host string + Caps map[string]struct{} Server Server State ClientState } @@ -65,6 +66,7 @@ func NewLocalClient(conn *net.Conn) (*Client, error) { Server: self, State: ClientStatePreRegistration, Nick: "*", + Caps: make(map[string]struct{}), } for range 10 { uid_ := []byte(self.SID) @@ -86,22 +88,32 @@ func NewLocalClient(conn *net.Conn) (*Client, error) { } func (client *Client) checkRegistration() error { - if client.State != ClientStatePreRegistration { - slog.Error("spurious call to checkRegistration", "client", client) - return ErrCallState - } - if client.Nick != "*" && client.Ident != "" { - return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome")) + switch client.State { + case ClientStatePreRegistration: + if client.Nick != "*" && client.Ident != "" { + client.State = ClientStateRegistered + return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome")) + } + return nil // Incomplete for registration + case ClientStateCapabilitiesFinished: + if client.Nick != "*" && client.Ident != "" { + client.State = ClientStateRegistered + return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome")) + } + return nil + default: + return nil } - return nil } type ClientState uint8 const ( - ClientStateRemote ClientState = iota - ClientStatePreRegistration + ClientStatePreRegistration ClientState = iota + ClientStateCapabilities + ClientStateCapabilitiesFinished ClientStateRegistered + ClientStateRemote ) var ( diff --git a/cmd_cap.go b/cmd_cap.go index 328bce6..c1ddb1d 100644 --- a/cmd_cap.go +++ b/cmd_cap.go @@ -16,13 +16,46 @@ func handleClientCap(msg RMsg, client *Client) error { } return nil } + if client.State == ClientStateRemote { + return ErrRemoteClient + } switch strings.ToUpper(msg.Params[0]) { case "LS": - err := client.Send(MakeMsg(self, "CAP", client.Nick, "LS", "sasl=PLAIN,EXTERNAL")) + if client.State == ClientStatePreRegistration { + client.State = ClientStateCapabilities + } + err := client.Send(MakeMsg(self, "CAP", client.Nick, "LS", capls)) + // TODO: Split when too long if err != nil { return err } case "REQ": + if client.State == ClientStatePreRegistration { + client.State = ClientStateCapabilities + } + caps := strings.Split(msg.Params[1], " ") + for _, c := range caps { + if c[0] == '-' { + // TODO: Remove capability + delete(client.Caps, c) + continue + } + _, ok := Caps[c] + if ok { + client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c)) + client.Caps[c] = struct{}{} + // TODO: This is terrible + } else { + client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c)) + } + } + case "END": + if client.State != ClientStateCapabilities { + // Just ignore it + return nil + } + client.State = ClientStateCapabilitiesFinished + client.checkRegistration() } return nil } diff --git a/cmd_nick.go b/cmd_nick.go index 9ac3cd8..f24a7fd 100644 --- a/cmd_nick.go +++ b/cmd_nick.go @@ -22,7 +22,7 @@ func handleClientNick(msg RMsg, client *Client) error { } } } else { - if !nickToClient.CompareAndDelete(client.Nick, client) { + if (client.State >= ClientStateRegistered || client.Nick != "*") && !nickToClient.CompareAndDelete(client.Nick, client) { slog.Error("nick inconsistent", "nick", client.Nick, "client", client) return fmt.Errorf("%w: %v", ErrInconsistentClient, client) } @@ -34,7 +34,7 @@ func handleClientNick(msg RMsg, client *Client) error { } client.Nick = msg.Params[0] } - if client.State == ClientStatePreRegistration { + if client.State < ClientStateRegistered { err := client.checkRegistration() if err != nil { return err diff --git a/cmd_user.go b/cmd_user.go index 4c4c86a..bf8df64 100644 --- a/cmd_user.go +++ b/cmd_user.go @@ -12,20 +12,20 @@ func handleClientUser(msg RMsg, client *Client) error { if len(msg.Params) < 4 { return client.Send(MakeMsg(self, ERR_NEEDMOREPARAMS, "USER", "Not enough parameters")) } - switch client.State { - case ClientStatePreRegistration: + switch { + case client.State < ClientStateRegistered: client.Ident = "~" + msg.Params[0] client.Gecos = msg.Params[3] err := client.checkRegistration() if err != nil { return err } - case ClientStateRegistered: + case client.State == ClientStateRegistered: err := client.Send(MakeMsg(self, ERR_ALREADYREGISTERED, client.Nick, "You may not reregister")) if err != nil { return err } - case ClientStateRemote: + case client.State == ClientStateRemote: } return nil } diff --git a/main.go b/main.go index bd96e5d..edb629a 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,8 @@ func main() { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) slog.SetDefault(logger) + setupCapls() + self = Server{ conn: nil, SID: "001", -- cgit v1.2.3