aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cap.go25
-rw-r--r--clients.go30
-rw-r--r--cmd_cap.go35
-rw-r--r--cmd_nick.go4
-rw-r--r--cmd_user.go8
-rw-r--r--main.go2
6 files changed, 88 insertions, 16 deletions
diff --git a/cap.go b/cap.go
new file mode 100644
index 0000000..83dfbda
--- /dev/null
+++ b/cap.go
@@ -0,0 +1,25 @@
+package main
+
+import (
+ "strings"
+)
+
+var Caps = map[string]string{
+ "sasl": "PLAIN,EXTERNAL",
+}
+
+var capls string
+
+// Can't be in init() because Caps will be registered with init in the future
+// and init()s are executed by filename alphabetical order
+func setupCapls() {
+ capls = ""
+ for k, v := range Caps {
+ capls += k
+ if v != "" {
+ capls += "=" + v
+ }
+ capls += " "
+ }
+ capls = strings.TrimSuffix(capls, " ")
+}
diff --git a/clients.go b/clients.go
index 07dc532..2719b08 100644
--- a/clients.go
+++ b/clients.go
@@ -15,6 +15,7 @@ type Client struct {
Ident string
Gecos string
Host string
+ Caps map[string]struct{}
Server Server
State ClientState
}
@@ -65,6 +66,7 @@ func NewLocalClient(conn *net.Conn) (*Client, error) {
Server: self,
State: ClientStatePreRegistration,
Nick: "*",
+ Caps: make(map[string]struct{}),
}
for range 10 {
uid_ := []byte(self.SID)
@@ -86,22 +88,32 @@ func NewLocalClient(conn *net.Conn) (*Client, error) {
}
func (client *Client) checkRegistration() error {
- if client.State != ClientStatePreRegistration {
- slog.Error("spurious call to checkRegistration", "client", client)
- return ErrCallState
- }
- if client.Nick != "*" && client.Ident != "" {
- return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome"))
+ switch client.State {
+ case ClientStatePreRegistration:
+ if client.Nick != "*" && client.Ident != "" {
+ client.State = ClientStateRegistered
+ return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome"))
+ }
+ return nil // Incomplete for registration
+ case ClientStateCapabilitiesFinished:
+ if client.Nick != "*" && client.Ident != "" {
+ client.State = ClientStateRegistered
+ return client.Send(MakeMsg(self, RPL_WELCOME, client.Nick, "Welcome"))
+ }
+ return nil
+ default:
+ return nil
}
- return nil
}
type ClientState uint8
const (
- ClientStateRemote ClientState = iota
- ClientStatePreRegistration
+ ClientStatePreRegistration ClientState = iota
+ ClientStateCapabilities
+ ClientStateCapabilitiesFinished
ClientStateRegistered
+ ClientStateRemote
)
var (
diff --git a/cmd_cap.go b/cmd_cap.go
index 328bce6..c1ddb1d 100644
--- a/cmd_cap.go
+++ b/cmd_cap.go
@@ -16,13 +16,46 @@ func handleClientCap(msg RMsg, client *Client) error {
}
return nil
}
+ if client.State == ClientStateRemote {
+ return ErrRemoteClient
+ }
switch strings.ToUpper(msg.Params[0]) {
case "LS":
- err := client.Send(MakeMsg(self, "CAP", client.Nick, "LS", "sasl=PLAIN,EXTERNAL"))
+ if client.State == ClientStatePreRegistration {
+ client.State = ClientStateCapabilities
+ }
+ err := client.Send(MakeMsg(self, "CAP", client.Nick, "LS", capls))
+ // TODO: Split when too long
if err != nil {
return err
}
case "REQ":
+ if client.State == ClientStatePreRegistration {
+ client.State = ClientStateCapabilities
+ }
+ caps := strings.Split(msg.Params[1], " ")
+ for _, c := range caps {
+ if c[0] == '-' {
+ // TODO: Remove capability
+ delete(client.Caps, c)
+ continue
+ }
+ _, ok := Caps[c]
+ if ok {
+ client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c))
+ client.Caps[c] = struct{}{}
+ // TODO: This is terrible
+ } else {
+ client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c))
+ }
+ }
+ case "END":
+ if client.State != ClientStateCapabilities {
+ // Just ignore it
+ return nil
+ }
+ client.State = ClientStateCapabilitiesFinished
+ client.checkRegistration()
}
return nil
}
diff --git a/cmd_nick.go b/cmd_nick.go
index 9ac3cd8..f24a7fd 100644
--- a/cmd_nick.go
+++ b/cmd_nick.go
@@ -22,7 +22,7 @@ func handleClientNick(msg RMsg, client *Client) error {
}
}
} else {
- if !nickToClient.CompareAndDelete(client.Nick, client) {
+ if (client.State >= ClientStateRegistered || client.Nick != "*") && !nickToClient.CompareAndDelete(client.Nick, client) {
slog.Error("nick inconsistent", "nick", client.Nick, "client", client)
return fmt.Errorf("%w: %v", ErrInconsistentClient, client)
}
@@ -34,7 +34,7 @@ func handleClientNick(msg RMsg, client *Client) error {
}
client.Nick = msg.Params[0]
}
- if client.State == ClientStatePreRegistration {
+ if client.State < ClientStateRegistered {
err := client.checkRegistration()
if err != nil {
return err
diff --git a/cmd_user.go b/cmd_user.go
index 4c4c86a..bf8df64 100644
--- a/cmd_user.go
+++ b/cmd_user.go
@@ -12,20 +12,20 @@ func handleClientUser(msg RMsg, client *Client) error {
if len(msg.Params) < 4 {
return client.Send(MakeMsg(self, ERR_NEEDMOREPARAMS, "USER", "Not enough parameters"))
}
- switch client.State {
- case ClientStatePreRegistration:
+ switch {
+ case client.State < ClientStateRegistered:
client.Ident = "~" + msg.Params[0]
client.Gecos = msg.Params[3]
err := client.checkRegistration()
if err != nil {
return err
}
- case ClientStateRegistered:
+ case client.State == ClientStateRegistered:
err := client.Send(MakeMsg(self, ERR_ALREADYREGISTERED, client.Nick, "You may not reregister"))
if err != nil {
return err
}
- case ClientStateRemote:
+ case client.State == ClientStateRemote:
}
return nil
}
diff --git a/main.go b/main.go
index bd96e5d..edb629a 100644
--- a/main.go
+++ b/main.go
@@ -12,6 +12,8 @@ func main() {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
+ setupCapls()
+
self = Server{
conn: nil,
SID: "001",