From 6fdea28236771ee1d90a6fc959075c79939ad566 Mon Sep 17 00:00:00 2001
From: Runxi Yu <me@runxiyu.org>
Date: Sun, 8 Dec 2024 13:57:25 +0800
Subject: CAP: Primitive negotiation

---
 cap.go      | 25 +++++++++++++++++++++++++
 clients.go  | 30 +++++++++++++++++++++---------
 cmd_cap.go  | 35 ++++++++++++++++++++++++++++++++++-
 cmd_nick.go |  4 ++--
 cmd_user.go |  8 ++++----
 main.go     |  2 ++
 6 files changed, 88 insertions(+), 16 deletions(-)
 create mode 100644 cap.go

diff --git a/cap.go b/cap.go
new file mode 100644
index 0000000..83dfbda
--- /dev/null
+++ b/cap.go
@@ -0,0 +1,25 @@
+package main
+
+import (
+	"strings"
+)
+
+var Caps = map[string]string{
+	"sasl": "PLAIN,EXTERNAL",
+}
+
+var capls string
+
+// Can't be in init() because Caps will be registered with init in the future
+// and init()s are executed by filename alphabetical order
+func setupCapls() {
+	capls = ""
+	for k, v := range Caps {
+		capls += k
+		if v != "" {
+			capls += "=" + v
+		}
+		capls += " "
+	}
+	capls = strings.TrimSuffix(capls, " ")
+}
diff --git a/clients.go b/clients.go
index 07dc532..2719b08 100644
--- a/clients.go
+++ b/clients.go
@@ -15,6 +15,7 @@ type Client struct {
 	Ident  string
 	Gecos  string
 	Host   string
+	Caps   map[string]struct{}
 	Server Server
 	State  ClientState
 }
@@ -65,6 +66,7 @@ func NewLocalClient(conn *net.Conn) (*Client, error) {
 		Server: self,
 		State:  ClientStatePreRegistration,
 		Nick:   "*",
+		Caps:   make(map[string]struct{}),
 	}
 	for range 10 {
 		uid_ := []byte(self.SID)
@@ -86,22 +88,32 @@ func NewLocalClient(conn *net.Conn) (*Client, error) {
 }
 
 func (client *Client) checkRegistration() error {
-	if client.State != ClientStatePreRegistration {
-		slog.Error("spurious call to checkRegistration", "client", client)
-		return ErrCallState
-	}
-	if client.Nick != "*" && client.Ident != "" {
-		return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome"))
+	switch client.State {
+	case ClientStatePreRegistration:
+		if client.Nick != "*" && client.Ident != "" {
+			client.State = ClientStateRegistered
+			return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome"))
+		}
+		return nil // Incomplete for registration
+	case ClientStateCapabilitiesFinished:
+		if client.Nick != "*" && client.Ident != "" {
+			client.State = ClientStateRegistered
+			return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome"))
+		}
+		return nil
+	default:
+		return nil
 	}
-	return nil
 }
 
 type ClientState uint8
 
 const (
-	ClientStateRemote ClientState = iota
-	ClientStatePreRegistration
+	ClientStatePreRegistration ClientState = iota
+	ClientStateCapabilities
+	ClientStateCapabilitiesFinished
 	ClientStateRegistered
+	ClientStateRemote
 )
 
 var (
diff --git a/cmd_cap.go b/cmd_cap.go
index 328bce6..c1ddb1d 100644
--- a/cmd_cap.go
+++ b/cmd_cap.go
@@ -16,13 +16,46 @@ func handleClientCap(msg RMsg, client *Client) error {
 		}
 		return nil
 	}
+	if client.State == ClientStateRemote {
+		return ErrRemoteClient
+	}
 	switch strings.ToUpper(msg.Params[0]) {
 	case "LS":
-		err := client.Send(MakeMsg(self, "CAP", client.Nick, "LS", "sasl=PLAIN,EXTERNAL"))
+		if client.State == ClientStatePreRegistration {
+			client.State = ClientStateCapabilities
+		}
+		err := client.Send(MakeMsg(self, "CAP", client.Nick, "LS", capls))
+		// TODO: Split when too long
 		if err != nil {
 			return err
 		}
 	case "REQ":
+		if client.State == ClientStatePreRegistration {
+			client.State = ClientStateCapabilities
+		}
+		caps := strings.Split(msg.Params[1], " ")
+		for _, c := range caps {
+			if c[0] == '-' {
+				// TODO: Remove capability
+				delete(client.Caps, c)
+				continue
+			}
+			_, ok := Caps[c]
+			if ok {
+				client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c))
+				client.Caps[c] = struct{}{}
+				// TODO: This is terrible
+			} else {
+				client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c))
+			}
+		}
+	case "END":
+		if client.State != ClientStateCapabilities {
+			// Just ignore it
+			return nil
+		}
+		client.State = ClientStateCapabilitiesFinished
+		client.checkRegistration()
 	}
 	return nil
 }
diff --git a/cmd_nick.go b/cmd_nick.go
index 9ac3cd8..f24a7fd 100644
--- a/cmd_nick.go
+++ b/cmd_nick.go
@@ -22,7 +22,7 @@ func handleClientNick(msg RMsg, client *Client) error {
 			}
 		}
 	} else {
-		if !nickToClient.CompareAndDelete(client.Nick, client) {
+		if (client.State >= ClientStateRegistered || client.Nick != "*") && !nickToClient.CompareAndDelete(client.Nick, client) {
 			slog.Error("nick inconsistent", "nick", client.Nick, "client", client)
 			return fmt.Errorf("%w: %v", ErrInconsistentClient, client)
 		}
@@ -34,7 +34,7 @@ func handleClientNick(msg RMsg, client *Client) error {
 		}
 		client.Nick = msg.Params[0]
 	}
-	if client.State == ClientStatePreRegistration {
+	if client.State < ClientStateRegistered {
 		err := client.checkRegistration()
 		if err != nil {
 			return err
diff --git a/cmd_user.go b/cmd_user.go
index 4c4c86a..bf8df64 100644
--- a/cmd_user.go
+++ b/cmd_user.go
@@ -12,20 +12,20 @@ func handleClientUser(msg RMsg, client *Client) error {
 	if len(msg.Params) < 4 {
 		return client.Send(MakeMsg(self, ERR_NEEDMOREPARAMS, "USER", "Not enough parameters"))
 	}
-	switch client.State {
-	case ClientStatePreRegistration:
+	switch {
+	case client.State < ClientStateRegistered:
 		client.Ident = "~" + msg.Params[0]
 		client.Gecos = msg.Params[3]
 		err := client.checkRegistration()
 		if err != nil {
 			return err
 		}
-	case ClientStateRegistered:
+	case client.State == ClientStateRegistered:
 		err := client.Send(MakeMsg(self, ERR_ALREADYREGISTERED, client.Nick, "You may not reregister"))
 		if err != nil {
 			return err
 		}
-	case ClientStateRemote:
+	case client.State == ClientStateRemote:
 	}
 	return nil
 }
diff --git a/main.go b/main.go
index bd96e5d..edb629a 100644
--- a/main.go
+++ b/main.go
@@ -12,6 +12,8 @@ func main() {
 	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
 	slog.SetDefault(logger)
 
+	setupCapls()
+
 	self = Server{
 		conn: nil,
 		SID:  "001",
-- 
cgit v1.2.3