package main
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"net"
"slices"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"go.lindenii.runxiyu.org/lindenii-common/clog"
"go.lindenii.runxiyu.org/lindenii-common/mailkit"
)
// mx_server_state_t records how far an inbound SMTP session has progressed
// through the HELO -> MAIL -> RCPT command sequence.
type mx_server_state_t uint

const (
	// mx_server_state_begin: no HELO has been accepted yet. A session
	// starts here and returns here after STARTTLS.
	mx_server_state_begin mx_server_state_t = 0
	// mx_server_state_helo: the client has greeted; no transaction is open.
	mx_server_state_helo mx_server_state_t = 1
	// mx_server_state_mail: MAIL FROM was accepted; no recipients yet.
	mx_server_state_mail mx_server_state_t = 2
	// mx_server_state_rcpt: at least one RCPT TO was accepted.
	mx_server_state_rcpt mx_server_state_t = 3
)
// mx_recv_session holds the per-connection state of one inbound SMTP session.
type mx_recv_session struct {
	// buf_conn buffers reads and writes on net_conn, or on tls_conn once
	// STARTTLS has been accepted.
	buf_conn *bufio.ReadWriter
	// net_conn is the accepted network connection.
	net_conn net.Conn
	// tls_conn is non-nil once STARTTLS has been accepted.
	tls_conn *tls.Conn

	// Snapshotted from the global configuration when the session starts.
	my_server_name string
	tls_config     *tls.Config
	db             *pgxpool.Pool

	// remote_server_name is the argument the client gave to HELO.
	// current_mail_from is the raw text following "FROM:" in MAIL.
	remote_server_name, current_mail_from string
	// current_rcpt_to lists the recipients accepted by RCPT.
	current_rcpt_to []string
	// server_state tracks where the client is in the command sequence.
	server_state mx_server_state_t
}
// handle runs one inbound SMTP session on session.net_conn.
//
// It snapshots the server name, TLS configuration and database pool under
// config_consistent_run, sends the 220 greeting, and then processes one
// command per line. It returns:
//   - nil after QUIT, having closed the connection;
//   - err_connection_handler_eof when the client disconnects between
//     commands;
//   - any other read error unchanged.
//
// Other exit paths do not close the connection; the caller is responsible
// for that.
func (session *mx_recv_session) handle(ctx context.Context) error {
	session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.net_conn), bufio.NewWriter(session.net_conn))
	config_consistent_run(func() {
		session.my_server_name = config.Server_name
		session.tls_config = config._tls_config
		session.db = global_db
	})

	// reply sends one complete response (already CRLF-terminated) and
	// flushes it. Write errors are deliberately ignored: a broken connection
	// shows up as an error on the next read. It reads session.buf_conn at
	// call time, so it follows the switch to TLS after STARTTLS.
	reply := func(response string) {
		_, _ = session.buf_conn.WriteString(response)
		_ = session.buf_conn.Flush()
	}

	// reset_transaction discards any MAIL/RCPT state gathered so far.
	reset_transaction := func() {
		session.current_mail_from = ""
		session.current_rcpt_to = nil
	}

	reply("220 " + session.my_server_name + " " + VERSION + "\r\n")
	session.server_state = mx_server_state_begin

	for {
		line, err := session.buf_conn.ReadString('\n')
		if err != nil {
			if errors.Is(err, io.EOF) {
				return err_connection_handler_eof
			}
			return err
		}
		line = strings.TrimSuffix(line, "\n")
		line = strings.TrimSuffix(line, "\r")
		verb, param, _ := strings.Cut(line, " ")
		cmd := strings.ToUpper(verb)

		switch cmd {
		case "STARTTLS":
			switch {
			case param != "":
				reply("501 5.5.4 Syntax error (no parameters allowed)\r\n")
			case session.tls_conn != nil:
				reply("554 5.5.1 Error: TLS already active\r\n")
			case session.tls_config == nil:
				// tls.Server with a nil config panics during the handshake,
				// which would take down the whole process.
				reply("454 4.7.0 TLS not available\r\n")
			default:
				reply("220 2.0.0 Ready to start TLS\r\n")
				session.tls_conn = tls.Server(session.net_conn, session.tls_config)
				session.buf_conn = bufio.NewReadWriter(bufio.NewReader(session.tls_conn), bufio.NewWriter(session.tls_conn))
				// RFC 3207 §4.2: discard everything learned before the
				// handshake, including the HELO argument.
				session.server_state = mx_server_state_begin
				session.remote_server_name = ""
				reset_transaction()
			}
		case "HELO", "EHLO":
			if param == "" { // TODO: actually validate the hostname
				reply("501 Syntax: " + cmd + " hostname\r\n")
				break
			}
			session.remote_server_name = param // TODO: use for checks/logging
			// A greeting also aborts any open transaction (RFC 5321 §4.1.4).
			session.server_state = mx_server_state_helo
			reset_transaction()
			if cmd == "EHLO" && session.tls_conn == nil && session.tls_config != nil {
				// Advertise STARTTLS only while it can actually be used.
				reply("250-" + session.my_server_name + "\r\n250 STARTTLS\r\n")
				break
			}
			reply("250 " + session.my_server_name + "\r\n")
		case "MAIL":
			if session.server_state == mx_server_state_begin {
				reply("503 5.5.1 Error: send HELO/EHLO first\r\n")
				break
			}
			if session.server_state != mx_server_state_helo {
				// A transaction is already open (MAIL or RCPT state).
				reply("503 5.5.1 Error: nested MAIL command\r\n")
				break
			}
			if len(param) <= len("FROM:") || !strings.EqualFold(param[:len("FROM:")], "FROM:") {
				reply("501 5.5.4 Syntax: MAIL FROM:<address>\r\n")
				break
			}
			// TODO: Address validation
			session.current_mail_from = param[len("FROM:"):]
			session.current_rcpt_to = nil
			session.server_state = mx_server_state_mail
			reply("250 2.1.0 Ok\r\n")
		case "RCPT":
			if session.server_state != mx_server_state_mail && session.server_state != mx_server_state_rcpt {
				reply("503 5.5.1 Error: need MAIL command\r\n")
				break
			}
			if len(param) <= len("TO:") || !strings.EqualFold(param[:len("TO:")], "TO:") {
				reply("501 5.5.4 Syntax: RCPT TO:<address>\r\n")
				break
			}
			recipient_address, _, _ := mailkit.Strip_angle_brackets(param[len("TO:"):])
			var count int
			err := session.db.QueryRow(ctx, "SELECT COUNT (*) FROM addresses WHERE address = $1", recipient_address).Scan(&count)
			if err != nil {
				// Keep database details in our log, not in the reply to an
				// untrusted peer.
				clog.Error("MX: Cannot look up recipient " + recipient_address + ": " + err.Error())
				reply("451 4.3.0 Internal error\r\n")
				break
			}
			if count == 0 {
				reply("550 5.1.1 Recipient address rejected: Local recipients not found: " + recipient_address + "\r\n")
				break
			}
			// A repeated RCPT for the same mailbox is accepted but recorded once.
			if !slices.Contains(session.current_rcpt_to, recipient_address) {
				session.current_rcpt_to = append(session.current_rcpt_to, recipient_address)
			}
			session.server_state = mx_server_state_rcpt
			reply("250 2.1.5 Ok\r\n")
		case "DATA":
			if session.server_state != mx_server_state_rcpt {
				reply("503 5.5.1 Error: need RCPT command\r\n")
				break
			}
			reply("354 End data with <CR><LF>.<CR><LF>\r\n")
			current_data, err := mx_read_data(session.buf_conn.Reader)
			if err != nil {
				return err
			}
			err = deliver_local(ctx, session.db, session.current_mail_from, session.current_rcpt_to, current_data, session.current_rcpt_to)
			var err_local_recipients_not_found *err_local_recipients_not_found_t
			switch {
			case errors.As(err, &err_local_recipients_not_found):
				reply("550 5.1.1 Recipient address rejected: " + err_local_recipients_not_found.Error() + "\r\n")
			case err != nil:
				// Transient (4xx) so the sender retries instead of bouncing.
				// The error text stays in our log only.
				clog.Error("MX: Local delivery failed: " + err.Error())
				reply("451 4.3.0 Error: local delivery failed\r\n")
			default:
				reply("250 2.0.0 Ok: Accepted\r\n")
			}
			session.server_state = mx_server_state_helo
			reset_transaction()
		case "QUIT":
			reply("221 2.0.0 Bye\r\n")
			if session.tls_conn != nil {
				_ = session.tls_conn.Close()
			}
			_ = session.net_conn.Close()
			return nil
		case "NOOP":
			reply("250 2.0.0 Ok\r\n")
		case "RSET":
			if session.server_state != mx_server_state_begin {
				session.server_state = mx_server_state_helo
			}
			reset_transaction()
			reply("250 2.0.0 Ok\r\n")
		default:
			reply("500 5.5.2 Error: command not recognized\r\n")
		}
	}
}

// mx_read_data reads an SMTP DATA payload from r, up to and including the
// end-of-data line "." (RFC 5321 §4.1.1.4).
//
// It returns the message with the "." line removed and leading
// transparency dots stripped (RFC 5321 §4.5.2). The CRLF that ends the last
// message line is kept, because it belongs to that line.
//
// The "." line only ends the data when the previous line ended in CRLF.
// A "." after a bare LF or CR is ordinary content, so sequences such as
// "<LF>.<CRLF>" cannot end the message early and smuggle commands.
// Lines of any length are accepted.
func mx_read_data(r *bufio.Reader) ([]byte, error) {
	var data bytes.Buffer
	terminator := []byte(".\r\n")
	crlf := []byte("\r\n")
	// The DATA command line has just been consumed, so the payload starts
	// at the beginning of a line. This is what lets an empty body
	// ("DATA\r\n.\r\n") terminate.
	at_line_start := true
	for {
		line, err := r.ReadBytes('\n')
		if err != nil {
			return nil, err
		}
		if at_line_start && line[0] == '.' {
			if bytes.Equal(line, terminator) {
				return data.Bytes(), nil
			}
			line = line[1:] // undo dot-stuffing
		}
		data.Write(line)
		at_line_start = bytes.HasSuffix(line, crlf)
	}
}
// mx_new_session wraps net_conn in a fresh mx_recv_session and runs it,
// returning whatever the session handler returns.
func mx_new_session(ctx context.Context, net_conn net.Conn) error {
	return (&mx_recv_session{net_conn: net_conn}).handle(ctx)
}
// serve_mx listens on the network and address taken from config.MX. Each
// accepted connection is served as an inbound SMTP session in its own
// goroutine.
//
// If it cannot listen it calls clog.Fatal (presumably terminating the
// process). Otherwise it loops forever accepting connections.
func serve_mx() {
	var mx_net, mx_addr string
	config_consistent_run(func() {
		mx_net = config.MX.Net
		mx_addr = config.MX.Addr
	})
	listener, err := net.Listen(mx_net, mx_addr)
	if err != nil {
		clog.Fatal(1, "MX: Cannot listen: "+err.Error())
	}
	defer listener.Close()
	clog.Info("MX: Listening via " + mx_net + " on " + mx_addr)
	for {
		conn, err := listener.Accept()
		if err != nil {
			// conn is nil on error; touching it would panic.
			clog.Error("MX: Cannot accept connection: " + err.Error())
			continue
		}
		remote := conn.RemoteAddr().String()
		clog.Info("MX: Accepted connection from " + remote)
		go func() {
			// The session closes the connection itself only on QUIT. Close
			// it on every exit path so EOF or error exits do not leak the
			// socket. A repeated Close just returns an error, which is
			// ignored here.
			defer conn.Close()
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			err := mx_new_session(ctx, conn)
			switch {
			case err == nil:
				clog.Info("MX: Connection for " + remote + " closed gracefully")
			case errors.Is(err, err_connection_handler_eof):
				clog.Info("MX: Connection for " + remote + " closed with EOF")
			default:
				clog.Error("MX: Connection handler for " + remote + " returned error: " + err.Error())
			}
		}()
	}
}