From f86cbd83d86473f405ec60a7a78cf28c64150f31 Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Mon, 13 Jan 2025 15:29:39 +0800 Subject: Refactor mx stuff --- main.go | 33 +------- mx_recv.go | 231 ---------------------------------------------------- serve_mx.go | 264 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 261 deletions(-) delete mode 100644 mx_recv.go create mode 100644 serve_mx.go diff --git a/main.go b/main.go index 4a4f242..51540f9 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,7 @@ package main import ( - "context" "flag" - "net" "go.lindenii.runxiyu.org/lindenii-common/clog" ) @@ -19,33 +17,8 @@ func main() { clog.Fatal(1, "Error while loading configuration file: "+err.Error()) } - listener, err := net.Listen(config.MX.Net, config.MX.Addr) - if err != nil { - clog.Fatal(1, "MX: Cannot listen: "+err.Error()) - } - defer listener.Close() - clog.Info("MX: Listening via " + config.MX.Net + " on " + config.MX.Addr) + go serve_mx() - for { - conn, err := listener.Accept() - if err != nil { - clog.Error("MX: Cannot accept connection: " + err.Error()) - } - clog.Info("MX: Accepted connection from " + conn.RemoteAddr().String()) - - go func() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - err := handle_mx_recv_conn(ctx, conn) - if err != nil { - if err == err_connection_handler_eof { - clog.Info("MX: Connection for " + conn.RemoteAddr().String() + " closed with EOF") - } else { - clog.Error("MX: Connection handler for " + conn.RemoteAddr().String() + " returned error: " + err.Error()) - } - } else { - clog.Info("MX: Connection for " + conn.RemoteAddr().String() + " closed gracefully") - } - }() - } + deadlock := make(chan struct{}) + deadlock <- struct{}{} } diff --git a/mx_recv.go b/mx_recv.go deleted file mode 100644 index 33b4012..0000000 --- a/mx_recv.go +++ /dev/null @@ -1,231 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "context" - "crypto/tls" - "errors" - "io" - "net" - "slices" - "strings" - - "github.com/jackc/pgx/v5/pgxpool" - "go.lindenii.runxiyu.org/lindenii-common/mailkit" -) - -type server_state_t uint - -const ( - server_state_begin server_state_t = iota - server_state_helo - server_state_mail - server_state_rcpt -) - -type mx_recv_session struct { - buf_conn *bufio.ReadWriter - net_conn net.Conn - tls_conn *tls.Conn - my_server_name string - tls_config *tls.Config - db *pgxpool.Pool - remote_server_name string - current_mail_from string - current_rcpt_to []string - server_state server_state_t -} - -func (session *mx_recv_session) handle(ctx context.Context) error { - session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.net_conn), bufio.NewWriter(session.net_conn)) - config_consistent_run(func() { - session.my_server_name = config.Server_name - session.tls_config = config._tls_config - session.db = global_db - }) - _, _ = session.buf_conn.WriteString("220 " + session.my_server_name + " " + VERSION + "\r\n") - _ = session.buf_conn.Flush() - session.server_state = server_state_begin - for { - line, err := session.buf_conn.ReadString('\n') - if err != nil { - if err == io.EOF { - return err_connection_handler_eof - } - return err - } - line = strings.TrimSuffix(line, "\n") - line = strings.TrimSuffix(line, "\r") - cmd_end := strings.IndexByte(line, ' ') - var param_start int - if cmd_end == -1 { - cmd_end = len(line) - param_start = len(line) - } else { - param_start = cmd_end + 1 - } - cmd := strings.ToUpper(line[:cmd_end]) - param := line[param_start:] - switch_cmd: - switch cmd { - case "STARTTLS": - if param != "" { - _, _ = session.buf_conn.WriteString("501 5.5.4 Syntax error (no parameters allowed)\r\n") - _ = session.buf_conn.Flush() - break - } - if session.tls_conn != nil { - _, _ = session.buf_conn.WriteString("554 5.5.1 Error: TLS already active\r\n") - _ = session.buf_conn.Flush() - break - } - _, _ = session.buf_conn.WriteString("220 2.0.0 Ready to start TLS\r\n") - _ = session.buf_conn.Flush() - session.tls_conn = tls.Server(session.net_conn, session.tls_config) - session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.tls_conn), bufio.NewWriter(session.tls_conn)) - session.server_state = server_state_begin - session.current_mail_from = "" - session.current_rcpt_to = []string{""} - case "HELO": - if param == "" { // TODO: actually validate the hostname - _, _ = session.buf_conn.WriteString("501 Syntax: HELO hostname\r\n") - _ = session.buf_conn.Flush() - break - } - session.remote_server_name = param - _ = session.remote_server_name // TODO - session.server_state = server_state_helo - _, _ = session.buf_conn.WriteString("250 " + session.my_server_name + "\r\n") - _ = session.buf_conn.Flush() - case "MAIL": - switch session.server_state { - case server_state_begin: - _, _ = session.buf_conn.WriteString("503 5.5.1 Error: send HELO/EHLO first\r\n") - _ = session.buf_conn.Flush() - break switch_cmd - case server_state_helo: - break - case server_state_mail: - _, _ = session.buf_conn.WriteString("503 5.5.1 Error: nested MAIL command\r\n") - _ = session.buf_conn.Flush() - break switch_cmd - } - if len(param) <= len("FROM:") || strings.ToUpper(param[:len("FROM:")]) != "FROM:" { - _, _ = session.buf_conn.WriteString("501 5.5.4 Syntax: MAIL FROM:
\r\n") - _ = session.buf_conn.Flush() - break - } - session.current_mail_from = param[len("FROM:"):] - session.current_rcpt_to = []string{} - session.server_state = server_state_mail - _, _ = session.buf_conn.WriteString("250 2.1.0 Ok\r\n") - _ = session.buf_conn.Flush() - // TODO: Address validation - case "RCPT": - if session.server_state != server_state_mail && session.server_state != server_state_rcpt { - _, _ = session.buf_conn.WriteString("503 5.5.1 Error: need MAIL command\r\n") - _ = session.buf_conn.Flush() - break - } - if len(param) <= len("TO:") || strings.ToUpper(param[:len("TO:")]) != "TO:" { - _, _ = session.buf_conn.WriteString("501 5.5.4 Syntax: RCPT TO:
\r\n") - _ = session.buf_conn.Flush() - break - } - recipient_address, _, _ := mailkit.Strip_angle_brackets(param[len("TO:"):]) - var count int - err := session.db.QueryRow(ctx, "SELECT COUNT (*) FROM addresses WHERE address = $1", recipient_address).Scan(&count) - if err != nil { - _, _ = session.buf_conn.WriteString("451 Internal error: " + err.Error() + "\r\n") - _ = session.buf_conn.Flush() - break - } - if count == 0 { - _, _ = session.buf_conn.WriteString("550 5.1.1 Recipient address rejected: Local recipients not found: " + recipient_address + "\r\n") - _ = session.buf_conn.Flush() - break - } - session.current_rcpt_to = append(session.current_rcpt_to, recipient_address) - session.server_state = server_state_rcpt - _, _ = session.buf_conn.WriteString("250 2.1.5 Ok\r\n") - _ = session.buf_conn.Flush() - case "DATA": - if session.server_state != server_state_rcpt { - _, _ = session.buf_conn.WriteString("503 5.5.1 Error: need RCPT command\r\n") - _ = session.buf_conn.Flush() - break - } - _, _ = session.buf_conn.WriteString("354 End data with .\r\n") - _ = session.buf_conn.Flush() - var current_data []byte - for { - tmp, err := session.buf_conn.ReadSlice('\r') - if err != nil { - return err - } - - // session.buf_conn.ReadSlice returns an internal buffer that gets - // overwritten on the next reader operation. So we must - // make a copy; also we have to allocate data_part to - // the correct length because [[builtin.copy]] copies - // min(len(dst), len(src)) items. - data_part := make([]byte, len(tmp)) - copy(data_part, tmp) - - next_four, err := session.buf_conn.Peek(4) - if err != nil { - return err - } - if bytes.Equal(next_four, []byte{'\n', '.', '\r', '\n'}) { - current_data = slices.Concat(current_data, data_part[:len(data_part)-1]) - break - } - current_data = slices.Concat(current_data, data_part) - } - _, err := session.buf_conn.Discard(4) - if err != nil { - return err - } - err = deliver_local(ctx, session.db, session.current_mail_from, session.current_rcpt_to, current_data, session.current_rcpt_to) - var err_local_recipients_not_found *err_local_recipients_not_found_t - switch { - case errors.As(err, &err_local_recipients_not_found): - _, _ = session.buf_conn.WriteString("550 5.1.1 Recipient address rejected: " + err_local_recipients_not_found.Error() + "\r\n") - case err != nil: - _, _ = session.buf_conn.WriteString("500 2.0.0 Error: " + err.Error() + "\r\n") - default: - _, _ = session.buf_conn.WriteString("250 2.0.0 Ok: Accepted\r\n") - } - _ = session.buf_conn.Flush() - session.server_state = server_state_helo - case "QUIT": - _, _ = session.buf_conn.WriteString("221 2.0.0 Bye\r\n") - _ = session.buf_conn.Flush() - if session.tls_conn != nil { - session.tls_conn.Close() - } - session.net_conn.Close() - return nil - case "NOOP": - _, _ = session.buf_conn.WriteString("250 2.0.0 Ok\r\n") - _ = session.buf_conn.Flush() - case "RSET": - if session.server_state != server_state_begin { - session.server_state = server_state_helo - } - _, _ = session.buf_conn.WriteString("250 2.0.0 Ok\r\n") - _ = session.buf_conn.Flush() - default: - _, _ = session.buf_conn.WriteString("500 5.5.2 Error: command not recognized\r\n") - _ = session.buf_conn.Flush() - } - } -} - -func handle_mx_recv_conn(ctx context.Context, net_conn net.Conn) error { - session := mx_recv_session{ - net_conn: net_conn, - } - return session.handle(ctx) -} diff --git a/serve_mx.go b/serve_mx.go new file mode 100644 index 0000000..f3c1527 --- /dev/null +++ b/serve_mx.go @@ -0,0 +1,264 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "io" + "net" + "slices" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" + "go.lindenii.runxiyu.org/lindenii-common/mailkit" + "go.lindenii.runxiyu.org/lindenii-common/clog" +) + +type server_state_t uint + +const ( + server_state_begin server_state_t = iota + server_state_helo + server_state_mail + server_state_rcpt +) + +type mx_recv_session struct { + buf_conn *bufio.ReadWriter + net_conn net.Conn + tls_conn *tls.Conn + my_server_name string + tls_config *tls.Config + db *pgxpool.Pool + remote_server_name string + current_mail_from string + current_rcpt_to []string + server_state server_state_t +} + +func (session *mx_recv_session) handle(ctx context.Context) error { + session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.net_conn), bufio.NewWriter(session.net_conn)) + config_consistent_run(func() { + session.my_server_name = config.Server_name + session.tls_config = config._tls_config + session.db = global_db + }) + _, _ = session.buf_conn.WriteString("220 " + session.my_server_name + " " + VERSION + "\r\n") + _ = session.buf_conn.Flush() + session.server_state = server_state_begin + for { + line, err := session.buf_conn.ReadString('\n') + if err != nil { + if err == io.EOF { + return err_connection_handler_eof + } + return err + } + line = strings.TrimSuffix(line, "\n") + line = strings.TrimSuffix(line, "\r") + cmd_end := strings.IndexByte(line, ' ') + var param_start int + if cmd_end == -1 { + cmd_end = len(line) + param_start = len(line) + } else { + param_start = cmd_end + 1 + } + cmd := strings.ToUpper(line[:cmd_end]) + param := line[param_start:] + switch_cmd: + switch cmd { + case "STARTTLS": + if param != "" { + _, _ = session.buf_conn.WriteString("501 5.5.4 Syntax error (no parameters allowed)\r\n") + _ = session.buf_conn.Flush() + break + } + if session.tls_conn != nil { + _, _ = session.buf_conn.WriteString("554 5.5.1 Error: TLS already active\r\n") + _ = session.buf_conn.Flush() + break + } + _, _ = session.buf_conn.WriteString("220 2.0.0 Ready to start TLS\r\n") + _ = session.buf_conn.Flush() + session.tls_conn = tls.Server(session.net_conn, session.tls_config) + session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.tls_conn), bufio.NewWriter(session.tls_conn)) + session.server_state = server_state_begin + session.current_mail_from = "" + session.current_rcpt_to = []string{""} + case "HELO": + if param == "" { // TODO: actually validate the hostname + _, _ = session.buf_conn.WriteString("501 Syntax: HELO hostname\r\n") + _ = session.buf_conn.Flush() + break + } + session.remote_server_name = param + _ = session.remote_server_name // TODO + session.server_state = server_state_helo + _, _ = session.buf_conn.WriteString("250 " + session.my_server_name + "\r\n") + _ = session.buf_conn.Flush() + case "MAIL": + switch session.server_state { + case server_state_begin: + _, _ = session.buf_conn.WriteString("503 5.5.1 Error: send HELO/EHLO first\r\n") + _ = session.buf_conn.Flush() + break switch_cmd + case server_state_helo: + break + case server_state_mail: + _, _ = session.buf_conn.WriteString("503 5.5.1 Error: nested MAIL command\r\n") + _ = session.buf_conn.Flush() + break switch_cmd + } + if len(param) <= len("FROM:") || strings.ToUpper(param[:len("FROM:")]) != "FROM:" { + _, _ = session.buf_conn.WriteString("501 5.5.4 Syntax: MAIL FROM:
\r\n") + _ = session.buf_conn.Flush() + break + } + session.current_mail_from = param[len("FROM:"):] + session.current_rcpt_to = []string{} + session.server_state = server_state_mail + _, _ = session.buf_conn.WriteString("250 2.1.0 Ok\r\n") + _ = session.buf_conn.Flush() + // TODO: Address validation + case "RCPT": + if session.server_state != server_state_mail && session.server_state != server_state_rcpt { + _, _ = session.buf_conn.WriteString("503 5.5.1 Error: need MAIL command\r\n") + _ = session.buf_conn.Flush() + break + } + if len(param) <= len("TO:") || strings.ToUpper(param[:len("TO:")]) != "TO:" { + _, _ = session.buf_conn.WriteString("501 5.5.4 Syntax: RCPT TO:
\r\n") + _ = session.buf_conn.Flush() + break + } + recipient_address, _, _ := mailkit.Strip_angle_brackets(param[len("TO:"):]) + var count int + err := session.db.QueryRow(ctx, "SELECT COUNT (*) FROM addresses WHERE address = $1", recipient_address).Scan(&count) + if err != nil { + _, _ = session.buf_conn.WriteString("451 Internal error: " + err.Error() + "\r\n") + _ = session.buf_conn.Flush() + break + } + if count == 0 { + _, _ = session.buf_conn.WriteString("550 5.1.1 Recipient address rejected: Local recipients not found: " + recipient_address + "\r\n") + _ = session.buf_conn.Flush() + break + } + session.current_rcpt_to = append(session.current_rcpt_to, recipient_address) + session.server_state = server_state_rcpt + _, _ = session.buf_conn.WriteString("250 2.1.5 Ok\r\n") + _ = session.buf_conn.Flush() + case "DATA": + if session.server_state != server_state_rcpt { + _, _ = session.buf_conn.WriteString("503 5.5.1 Error: need RCPT command\r\n") + _ = session.buf_conn.Flush() + break + } + _, _ = session.buf_conn.WriteString("354 End data with .\r\n") + _ = session.buf_conn.Flush() + var current_data []byte + for { + tmp, err := session.buf_conn.ReadSlice('\r') + if err != nil { + return err + } + + // session.buf_conn.ReadSlice returns an internal buffer that gets + // overwritten on the next reader operation. So we must + // make a copy; also we have to allocate data_part to + // the correct length because [[builtin.copy]] copies + // min(len(dst), len(src)) items. + data_part := make([]byte, len(tmp)) + copy(data_part, tmp) + + next_four, err := session.buf_conn.Peek(4) + if err != nil { + return err + } + if bytes.Equal(next_four, []byte{'\n', '.', '\r', '\n'}) { + current_data = slices.Concat(current_data, data_part[:len(data_part)-1]) + break + } + current_data = slices.Concat(current_data, data_part) + } + _, err := session.buf_conn.Discard(4) + if err != nil { + return err + } + err = deliver_local(ctx, session.db, session.current_mail_from, session.current_rcpt_to, current_data, session.current_rcpt_to) + var err_local_recipients_not_found *err_local_recipients_not_found_t + switch { + case errors.As(err, &err_local_recipients_not_found): + _, _ = session.buf_conn.WriteString("550 5.1.1 Recipient address rejected: " + err_local_recipients_not_found.Error() + "\r\n") + case err != nil: + _, _ = session.buf_conn.WriteString("500 2.0.0 Error: " + err.Error() + "\r\n") + default: + _, _ = session.buf_conn.WriteString("250 2.0.0 Ok: Accepted\r\n") + } + _ = session.buf_conn.Flush() + session.server_state = server_state_helo + case "QUIT": + _, _ = session.buf_conn.WriteString("221 2.0.0 Bye\r\n") + _ = session.buf_conn.Flush() + if session.tls_conn != nil { + session.tls_conn.Close() + } + session.net_conn.Close() + return nil + case "NOOP": + _, _ = session.buf_conn.WriteString("250 2.0.0 Ok\r\n") + _ = session.buf_conn.Flush() + case "RSET": + if session.server_state != server_state_begin { + session.server_state = server_state_helo + } + _, _ = session.buf_conn.WriteString("250 2.0.0 Ok\r\n") + _ = session.buf_conn.Flush() + default: + _, _ = session.buf_conn.WriteString("500 5.5.2 Error: command not recognized\r\n") + _ = session.buf_conn.Flush() + } + } +} + +func handle_mx_recv_conn(ctx context.Context, net_conn net.Conn) error { + session := mx_recv_session{ + net_conn: net_conn, + } + return session.handle(ctx) +} + +func serve_mx() { + listener, err := net.Listen(config.MX.Net, config.MX.Addr) + if err != nil { + clog.Fatal(1, "MX: Cannot listen: "+err.Error()) + } + defer listener.Close() + clog.Info("MX: Listening via " + config.MX.Net + " on " + config.MX.Addr) + + for { + conn, err := listener.Accept() + if err != nil { + clog.Error("MX: Cannot accept connection: " + err.Error()) + } + clog.Info("MX: Accepted connection from " + conn.RemoteAddr().String()) + + go func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := handle_mx_recv_conn(ctx, conn) + if err != nil { + if err == err_connection_handler_eof { + clog.Info("MX: Connection for " + conn.RemoteAddr().String() + " closed with EOF") + } else { + clog.Error("MX: Connection handler for " + conn.RemoteAddr().String() + " returned error: " + err.Error()) + } + } else { + clog.Info("MX: Connection for " + conn.RemoteAddr().String() + " closed gracefully") + } + }() + } +} -- cgit v1.2.3