From 9e1c9c5f43e2f0cf2dd1fa29b1ba3512dbbaeb7e Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sun, 8 Dec 2024 15:17:50 +0800 Subject: Implement SASL stub --- cap_sasl.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ clients.go | 8 ++++-- cmd_cap.go | 15 ++++++++--- panics.go | 5 ++++ 4 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 cap_sasl.go create mode 100644 panics.go diff --git a/cap_sasl.go b/cap_sasl.go new file mode 100644 index 0000000..7044efb --- /dev/null +++ b/cap_sasl.go @@ -0,0 +1,86 @@ +package main + +import ( + "encoding/base64" + "bytes" +) + +type ExtraSasl struct { + AuthMethod string +} + +const ( + RPL_LOGGEDIN = "900" + RPL_LOGGEDOUT = "901" + ERR_NICKLOCKED = "902" + RPL_SASLSUCCESS = "903" + ERR_SASLFAIL = "904" + ERR_SASLTOOLONG = "905" + ERR_SASLABORTED = "906" + ERR_SASLALREADY = "907" + RPL_SASLMECHS = "908" +) + +const ( + panicSaslMethod = "stored illegal SASL method" +) + +func init() { + Caps["sasl"] = "PLAIN,EXTERNAL" + CommandHandlers["AUTHENTICATE"] = handleClientAuthenticate +} + +func handleClientAuthenticate(msg RMsg, client *Client) error { + _, ok := client.Caps["sasl"] + if !ok { + return client.Send(MakeMsg(self, "TODO", "you're trying to sasl without requesting for it")) + } + + if len(msg.Params) < 1 { + return client.Send(MakeMsg(self, ERR_NEEDMOREPARAMS, "AUTHENTICATE", "Not enough parameters")) + } + + extraSasl_, ok := client.Extra["sasl"] + if !ok { + client.Extra["sasl"] = &ExtraSasl{} + extraSasl_ = client.Extra["sasl"] + } + extraSasl, ok := extraSasl_.(*ExtraSasl) + if !ok { + panic(panicType) + } + + switch extraSasl.AuthMethod { + case "": + if msg.Params[0] != "PLAIN" && msg.Params[0] != "EXTERNAL" { + return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (invalid method)")) + } + extraSasl.AuthMethod = msg.Params[0] + return client.Send(MakeMsg(self, "AUTHENTICATE", "+")) + case "*": // Abort + extraSasl.AuthMethod = "" + return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (aborted)")) + case "EXTERNAL": + extraSasl.AuthMethod = "" + return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed")) + case "PLAIN": + extraSasl.AuthMethod = "" + saslPlainData, err := base64.StdEncoding.DecodeString(msg.Params[0]) + if err != nil { + return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (base64 decoding error)")) + } + saslPlainSegments := bytes.Split(saslPlainData, []byte{0}) + if len(saslPlainSegments) != 3 { + return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed (not three segments)")) + } + _ = string(saslPlainSegments[0]) // authzid unused + authcid := string(saslPlainSegments[1]) + passwd := string(saslPlainSegments[2]) + if authcid == "runxiyu" && passwd == "hunter2" { + return client.Send(MakeMsg(self, RPL_SASLSUCCESS, client.Nick, "SASL authentication successful")) + } + return client.Send(MakeMsg(self, ERR_SASLFAIL, client.Nick, "SASL authentication failed")) + default: + panic(panicSaslMethod) + } +} diff --git a/clients.go b/clients.go index 2719b08..b141cf8 100644 --- a/clients.go +++ b/clients.go @@ -16,6 +16,7 @@ type Client struct { Gecos string Host string Caps map[string]struct{} + Extra map[string]any Server Server State ClientState } @@ -55,8 +56,10 @@ func (client *Client) Teardown() { if !uidToClient.CompareAndDelete(client.UID, client) { slog.Error("uid inconsistent", "uid", client.UID, "client", client) } - if !nickToClient.CompareAndDelete(client.Nick, client) { - slog.Error("nick inconsistent", "nick", client.Nick, "client", client) + if (client.State >= ClientStateRegistered || client.Nick != "*") { + if !nickToClient.CompareAndDelete(client.Nick, client) { + slog.Error("nick inconsistent", "nick", client.Nick, "client", client) + } } } @@ -67,6 +70,7 @@ func NewLocalClient(conn *net.Conn) (*Client, error) { State: ClientStatePreRegistration, Nick: "*", Caps: make(map[string]struct{}), + Extra: make(map[string]any), } for range 10 { uid_ := []byte(self.SID) diff --git a/cmd_cap.go b/cmd_cap.go index a2af228..65313a4 100644 --- a/cmd_cap.go +++ b/cmd_cap.go @@ -42,11 +42,17 @@ func handleClientCap(msg RMsg, client *Client) error { } _, ok := Caps[c] if ok { - client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c)) + err := client.Send(MakeMsg(self, "CAP", client.Nick, "ACK", c)) + if err != nil { + return err + } client.Caps[c] = struct{}{} // TODO: This is terrible } else { - client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c)) + err := client.Send(MakeMsg(self, "CAP", client.Nick, "NAK", c)) + if err != nil { + return err + } } } case "END": @@ -55,7 +61,10 @@ func handleClientCap(msg RMsg, client *Client) error { return nil } client.State = ClientStateCapabilitiesFinished - client.checkRegistration() + err := client.checkRegistration() + if err != nil { + return err + } } return nil } diff --git a/panics.go b/panics.go new file mode 100644 index 0000000..5bff91e --- /dev/null +++ b/panics.go @@ -0,0 +1,5 @@ +package main + +const ( + panicType = "type error" +) -- cgit v1.2.3