package main

import (
	"bufio"
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"slices"
	"strings"

	"github.com/jackc/pgx/v5/pgxpool"
	"go.lindenii.runxiyu.org/lindenii-common/clog"
	"go.lindenii.runxiyu.org/lindenii-common/mailkit"
)

// server_state_t tracks where an inbound SMTP session is in the
// HELO/EHLO -> MAIL -> RCPT -> DATA command sequence.
type server_state_t uint

const (
	server_state_begin server_state_t = iota // no HELO/EHLO yet (also right after STARTTLS)
	server_state_helo                        // greeted, no mail transaction open
	server_state_mail                        // MAIL accepted, waiting for RCPT
	server_state_rcpt                        // at least one RCPT accepted, DATA allowed
)

// mx_recv_session holds the per-connection state of one inbound MX session.
type mx_recv_session struct {
	buf_conn           *bufio.ReadWriter // wraps net_conn, or tls_conn after STARTTLS
	net_conn           net.Conn
	tls_conn           *tls.Conn // non-nil once STARTTLS has been accepted
	my_server_name     string
	tls_config         *tls.Config
	db                 *pgxpool.Pool
	remote_server_name string
	current_mail_from  string
	current_rcpt_to    []string
	server_state       server_state_t
}

// write_reply sends one pre-formatted SMTP reply (which must already end in
// CRLF) and flushes it. Write errors are deliberately ignored: a broken
// connection surfaces as a read error on the next command.
func (session *mx_recv_session) write_reply(reply string) {
	_, _ = session.buf_conn.WriteString(reply)
	_ = session.buf_conn.Flush()
}

// reset_transaction forgets the envelope of the current mail transaction.
func (session *mx_recv_session) reset_transaction() {
	session.current_mail_from = ""
	session.current_rcpt_to = nil
}

// handle runs the SMTP command loop for one inbound connection until the
// peer sends QUIT (returns nil), the connection reaches EOF (returns
// err_connection_handler_eof), or another I/O error occurs (returned as is).
func (session *mx_recv_session) handle(ctx context.Context) error {
	session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.net_conn), bufio.NewWriter(session.net_conn))

	config_consistent_run(func() {
		session.my_server_name = config.Server_name
		session.tls_config = config._tls_config
		session.db = global_db
	})

	session.write_reply("220 " + session.my_server_name + " " + VERSION + "\r\n")
	session.server_state = server_state_begin

	for {
		// TODO: bound the line length and set read deadlines; ReadString
		// grows without limit and blocks for as long as the peer stays idle.
		line, err := session.buf_conn.ReadString('\n')
		if err != nil {
			if errors.Is(err, io.EOF) {
				return err_connection_handler_eof
			}
			return err
		}
		line = strings.TrimSuffix(line, "\n")
		line = strings.TrimSuffix(line, "\r")

		// The verb is everything before the first space; the rest is its parameter.
		cmd, param, _ := strings.Cut(line, " ")
		cmd = strings.ToUpper(cmd)

		switch cmd {
		case "STARTTLS":
			session.handle_starttls(param)
		case "HELO":
			session.handle_helo(param, false)
		case "EHLO":
			session.handle_helo(param, true)
		case "MAIL":
			session.handle_mail(param)
		case "RCPT":
			session.handle_rcpt(ctx, param)
		case "DATA":
			if err := session.handle_data(ctx); err != nil {
				return err
			}
		case "QUIT":
			session.write_reply("221 2.0.0 Bye\r\n")
			if session.tls_conn != nil {
				_ = session.tls_conn.Close()
			}
			_ = session.net_conn.Close()
			return nil
		case "NOOP":
			session.write_reply("250 2.0.0 Ok\r\n")
		case "RSET":
			if session.server_state != server_state_begin {
				session.server_state = server_state_helo
			}
			session.reset_transaction()
			session.write_reply("250 2.0.0 Ok\r\n")
		default:
			session.write_reply("500 5.5.2 Error: command not recognized\r\n")
		}
	}
}

// handle_starttls upgrades the connection to TLS (RFC 3207).
func (session *mx_recv_session) handle_starttls(param string) {
	switch {
	case param != "":
		session.write_reply("501 5.5.4 Syntax error (no parameters allowed)\r\n")
		return
	case session.tls_conn != nil:
		session.write_reply("554 5.5.1 Error: TLS already active\r\n")
		return
	case session.tls_config == nil:
		// tls.Server requires a non-nil config; refuse instead of failing later.
		session.write_reply("454 4.7.0 TLS not available due to temporary reason\r\n")
		return
	}

	// Flushed over the plaintext connection before switching.
	session.write_reply("220 2.0.0 Ready to start TLS\r\n")

	session.tls_conn = tls.Server(session.net_conn, session.tls_config)
	// A fresh reader is built on top of the TLS connection on purpose: any
	// plaintext the client pipelined after STARTTLS is left in the old
	// buffer and dropped, so it cannot be treated as if sent over TLS.
	session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.tls_conn), bufio.NewWriter(session.tls_conn))

	// RFC 3207 §4.2: forget everything learned before TLS; the client has
	// to greet again.
	session.server_state = server_state_begin
	session.remote_server_name = ""
	session.reset_transaction()
}

// handle_helo processes HELO (extended == false) and EHLO (extended == true).
// Either greeting resets any open mail transaction.
func (session *mx_recv_session) handle_helo(param string, extended bool) {
	if param == "" {
		// TODO: actually validate the hostname
		if extended {
			session.write_reply("501 Syntax: EHLO hostname\r\n")
		} else {
			session.write_reply("501 Syntax: HELO hostname\r\n")
		}
		return
	}
	session.remote_server_name = param // TODO: use for validation/logging
	session.server_state = server_state_helo
	session.reset_transaction()

	if !extended {
		session.write_reply("250 " + session.my_server_name + "\r\n")
		return
	}

	// EHLO reply: the first line is the greeting, each further line names one
	// extension; every line except the last uses "250-".
	lines := []string{session.my_server_name}
	if session.tls_conn == nil && session.tls_config != nil {
		lines = append(lines, "STARTTLS")
	}
	var reply strings.Builder
	for i, text := range lines {
		if i == len(lines)-1 {
			reply.WriteString("250 ")
		} else {
			reply.WriteString("250-")
		}
		reply.WriteString(text)
		reply.WriteString("\r\n")
	}
	session.write_reply(reply.String())
}

// handle_mail opens a mail transaction with MAIL FROM.
func (session *mx_recv_session) handle_mail(param string) {
	switch session.server_state {
	case server_state_begin:
		session.write_reply("503 5.5.1 Error: send HELO/EHLO first\r\n")
		return
	case server_state_mail, server_state_rcpt:
		// A transaction is already open; it must end with DATA or RSET first.
		session.write_reply("503 5.5.1 Error: nested MAIL command\r\n")
		return
	}
	if len(param) <= len("FROM:") || !strings.EqualFold(param[:len("FROM:")], "FROM:") {
		session.write_reply("501 5.5.4 Syntax: MAIL FROM:<address>\r\n")
		return
	}
	// TODO: Address validation; the value is stored verbatim (including any
	// angle brackets and ESMTP parameters).
	session.current_mail_from = param[len("FROM:"):]
	session.current_rcpt_to = nil
	session.server_state = server_state_mail
	session.write_reply("250 2.1.0 Ok\r\n")
}

// handle_rcpt adds a recipient after checking it exists in the addresses table.
func (session *mx_recv_session) handle_rcpt(ctx context.Context, param string) {
	if session.server_state != server_state_mail && session.server_state != server_state_rcpt {
		session.write_reply("503 5.5.1 Error: need MAIL command\r\n")
		return
	}
	if len(param) <= len("TO:") || !strings.EqualFold(param[:len("TO:")], "TO:") {
		session.write_reply("501 5.5.4 Syntax: RCPT TO:<address>\r\n")
		return
	}
	// NOTE(review): the second and third results of Strip_angle_brackets are
	// discarded; if either signals malformed input it should be checked here.
	recipient_address, _, _ := mailkit.Strip_angle_brackets(param[len("TO:"):])

	var count int
	err := session.db.QueryRow(ctx, "SELECT COUNT (*) FROM addresses WHERE address = $1", recipient_address).Scan(&count)
	if err != nil {
		// Database details are logged locally, not sent to the remote peer.
		clog.Error("MX: Cannot look up recipient " + recipient_address + ": " + err.Error())
		session.write_reply("451 4.3.0 Internal error\r\n")
		return
	}
	if count == 0 {
		session.write_reply("550 5.1.1 Recipient address rejected: Local recipients not found: " + recipient_address + "\r\n")
		return
	}
	// A repeated RCPT for the same address is acknowledged but recorded once,
	// so it is passed to deliver_local only once.
	if !slices.Contains(session.current_rcpt_to, recipient_address) {
		session.current_rcpt_to = append(session.current_rcpt_to, recipient_address)
	}
	session.server_state = server_state_rcpt
	session.write_reply("250 2.1.5 Ok\r\n")
}

// handle_data receives the message body and hands it to deliver_local.
// Only a read error on the connection is returned; delivery failures are
// reported to the peer and the session continues.
func (session *mx_recv_session) handle_data(ctx context.Context) error {
	if session.server_state != server_state_rcpt {
		session.write_reply("503 5.5.1 Error: need RCPT command\r\n")
		return nil
	}
	session.write_reply("354 End data with <CR><LF>.<CR><LF>\r\n")

	data, err := session.read_mail_data()
	if err != nil {
		return err
	}

	err = deliver_local(ctx, session.db, session.current_mail_from, session.current_rcpt_to, data, session.current_rcpt_to)
	var err_local_recipients_not_found *err_local_recipients_not_found_t
	switch {
	case errors.As(err, &err_local_recipients_not_found):
		session.write_reply("550 5.1.1 Recipient address rejected: " + err_local_recipients_not_found.Error() + "\r\n")
	case err != nil:
		// Temporary failure so the sender retries instead of bouncing;
		// internal error text stays in the local log.
		clog.Error("MX: Cannot deliver message: " + err.Error())
		session.write_reply("451 4.3.0 Error: local delivery failed\r\n")
	default:
		session.write_reply("250 2.0.0 Ok: Accepted\r\n")
	}
	session.server_state = server_state_helo
	session.reset_transaction()
	return nil
}

// read_mail_data reads the DATA payload up to the "<CRLF>.<CRLF>" terminator.
// Following RFC 5321 §4.5.2, one leading '.' is removed from each line, and
// the CRLF before the terminating dot is kept as part of the message. The
// terminator is only recognised after a real CRLF (or at the very start of
// the data), never after a bare LF.
//
// NOTE(review): the message size is not bounded here.
func (session *mx_recv_session) read_mail_data() ([]byte, error) {
	var data bytes.Buffer
	// True when the previous chunk ended in CRLF, i.e. we are at a line start.
	at_line_start := true
	for {
		line, err := session.buf_conn.ReadBytes('\n')
		if err != nil {
			return nil, err
		}
		// ReadBytes without error always returns at least "\n", so line[0] is valid.
		if at_line_start {
			if bytes.Equal(line, []byte(".\r\n")) {
				return data.Bytes(), nil
			}
			if line[0] == '.' {
				line = line[1:]
			}
		}
		at_line_start = bytes.HasSuffix(line, []byte("\r\n"))
		data.Write(line)
	}
}

// mx_new_session serves a single inbound SMTP connection; see handle for the
// meaning of the returned error.
func mx_new_session(ctx context.Context, net_conn net.Conn) error {
	session := mx_recv_session{
		net_conn: net_conn,
	}
	return session.handle(ctx)
}

// serve_mx listens on the configured MX address and serves every accepted
// connection in its own goroutine. It returns once the listener is closed.
// A failure to listen is reported through clog.Fatal.
func serve_mx() {
	var mx_net, mx_addr string
	config_consistent_run(func() {
		mx_net = config.MX.Net
		mx_addr = config.MX.Addr
	})

	listener, err := net.Listen(mx_net, mx_addr)
	if err != nil {
		clog.Fatal(1, "MX: Cannot listen: "+err.Error())
		return // in case clog.Fatal does not terminate the process
	}
	defer listener.Close()
	clog.Info("MX: Listening via " + mx_net + " on " + mx_addr)

	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			clog.Error("MX: Cannot accept connection: " + err.Error())
			// conn is nil here; never fall through and use it.
			continue
		}
		remote_addr := conn.RemoteAddr().String()
		clog.Info("MX: Accepted connection from " + remote_addr)

		go func() {
			// The session itself only closes the socket on QUIT; make sure
			// EOF and error paths release it as well.
			defer conn.Close()
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			err := mx_new_session(ctx, conn)
			switch {
			case err == nil:
				clog.Info("MX: Connection for " + remote_addr + " closed gracefully")
			case errors.Is(err, err_connection_handler_eof):
				clog.Info("MX: Connection for " + remote_addr + " closed with EOF")
			default:
				clog.Error("MX: Connection handler for " + remote_addr + " returned error: " + err.Error())
			}
		}()
	}
}