aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cap_sasl.go86
-rw-r--r--clients.go8
-rw-r--r--cmd_cap.go15
-rw-r--r--panics.go5
4 files changed, 109 insertions, 5 deletions
diff --git a/cap_sasl.go b/cap_sasl.go
new file mode 100644
index 0000000..7044efb
--- /dev/null
+++ b/cap_sasl.go
@@ -0,0 +1,86 @@
+package main
+
+import (
+ "encoding/base64"
+ "bytes"
+)
+
+type ExtraSasl struct {
+ AuthMethod string
+}
+
+const (
+ RPL_LOGGEDIN = "900"
+ RPL_LOGGEDOUT = "901"
+ ERR_NICKLOCKED = "902"
+ RPL_SASLSUCCESS = "903"
+ ERR_SASLFAIL = "904"
+ ERR_SASLTOOLONG = "905"
+ ERR_SASLABORTED = "906"
+ ERR_SASLALREADY = "907"
+ RPL_SASLMECHS = "908"
+)
+
+const (
+ panicSaslMethod = "stored illegal SASL method"
+)
+
+func init() {
+ Caps["sasl"] = "PLAIN,EXTERNAL"
+ CommandHandlers["AUTHENTICATE"] = handleClientAuthenticate
+}
+
+func handleClientAuthenticate(msg RMsg, client *Client) error {
+ _, ok := client.Caps["sasl"]
+ if !ok {
+ return client.Send(MakeMsg(self, "TODO", "you're trying to sasl without requesting for it"))
+ }
+
+ if len(msg.Params) < 1 {
+ return client.Send(MakeMsg(self, ERR_NEEDMOREPARAMS, "AUTHENTICATE", "Not enough parameters"))
+ }
+
+ extraSasl_, ok := client.Extra["sasl"]
+ if !ok {
+ client.Extra["sasl"] = &ExtraSasl{}
+ extraSasl_ = client.Extra["sasl"]
+ }
+ extraSasl, ok := extraSasl_.(*ExtraSasl)
+ if !ok {
+ panic(panicType)
+ }
+
+ switch extraSasl.AuthMethod {
+ case "":
+ if msg.Params[0] != "PLAIN" && msg.Params[0] != "EXTERNAL" {
+ return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (invalid method)"))
+ }
+ extraSasl.AuthMethod = msg.Params[0]
+ return client.Send(MakeMsg(self, "AUTHENTICATE", "+"))
+ case "*": // Abort
+ extraSasl.AuthMethod = ""
+ return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (aborted)"))
+ case "EXTERNAL":
+ extraSasl.AuthMethod = ""
+ return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed"))
+ case "PLAIN":
+ extraSasl.AuthMethod = ""
+ saslPlainData, err := base64.StdEncoding.DecodeString(msg.Params[0])
+ if err != nil {
+ return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (base64 decoding error)"))
+ }
+ saslPlainSegments := bytes.Split(saslPlainData, []byte{0})
+ if len(saslPlainSegments) != 3 {
+ return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (not three segments)"))
+ }
+ _ = string(saslPlainSegments[0]) // authzid unused
+ authcid := string(saslPlainSegments[1])
+ passwd := string(saslPlainSegments[2])
+ if authcid == "runxiyu" && passwd == "hunter2" {
+ return client.Send(MakeMsg(self, RPL_SASLSUCCESS, client.Nick, "SASL authentication successful"))
+ }
+ return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed"))
+ default:
+ panic(panicSaslMethod)
+ }
+}
diff --git a/clients.go b/clients.go
index 2719b08..b141cf8 100644
--- a/clients.go
+++ b/clients.go
@@ -16,6 +16,7 @@ type Client struct {
Gecos string
Host string
Caps map[string]struct{}
+ Extra map[string]any
Server Server
State ClientState
}
@@ -55,8 +56,10 @@ func (client *Client) Teardown() {
if !uidToClient.CompareAndDelete(client.UID, client) {
slog.Error("uid inconsistent", "uid", client.UID, "client", client)
}
- if !nickToClient.CompareAndDelete(client.Nick, client) {
- slog.Error("nick inconsistent", "nick", client.Nick, "client", client)
+ if (client.State >= ClientStateRegistered || client.Nick != "*") {
+ if !nickToClient.CompareAndDelete(client.Nick, client) {
+ slog.Error("nick inconsistent", "nick", client.Nick, "client", client)
+ }
}
}
@@ -67,6 +70,7 @@ func NewLocalClient(conn *net.Conn) (*Client, error) {
State: ClientStatePreRegistration,
Nick: "*",
Caps: make(map[string]struct{}),
+ Extra: make(map[string]any),
}
for range 10 {
uid_ := []byte(self.SID)
diff --git a/cmd_cap.go b/cmd_cap.go
index a2af228..65313a4 100644
--- a/cmd_cap.go
+++ b/cmd_cap.go
@@ -42,11 +42,17 @@ func handleClientCap(msg RMsg, client *Client) error {
}
_, ok := Caps[c]
if ok {
- client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c))
+ err := client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c))
+ if err != nil {
+ return err
+ }
client.Caps[c] = struct{}{}
// TODO: This is terrible
} else {
- client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c))
+ err := client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c))
+ if err != nil {
+ return err
+ }
}
}
case "END":
@@ -55,7 +61,10 @@ func handleClientCap(msg RMsg, client *Client) error {
return nil
}
client.State = ClientStateCapabilitiesFinished
- client.checkRegistration()
+ err := client.checkRegistration()
+ if err != nil {
+ return err
+ }
}
return nil
}
diff --git a/panics.go b/panics.go
new file mode 100644
index 0000000..5bff91e
--- /dev/null
+++ b/panics.go
@@ -0,0 +1,5 @@
+package main
+
+const (
+ panicType = "type error"
+)