internal/honeypot/smtp/smtp.go

package smtp

import (
	"bufio"
	"bytes"
	"context"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"log/slog"
	"net"
	"regexp"
	"strings"
	"sync"
	"time"

	"honeypot/internal/database"
	"honeypot/internal/honeypot"
	"honeypot/internal/logger"
	tlscert "honeypot/internal/tls"
	"honeypot/internal/types"
	"honeypot/internal/utils"
)

// Identity of the SMTP honeypot and limits applied to every session.
const (
	// HoneypotType identifies this honeypot in emitted log events.
	HoneypotType = types.HoneypotTypeSMTP
	// HoneypotLabel is the human-readable name returned by Label.
	HoneypotLabel = "SMTP"

	// ServerBanner is the 220 greeting written when a session opens.
	ServerBanner = "220 localhost ESMTP ready\r\n"
	// DefaultConnectionTimeout is the absolute deadline for a whole session.
	DefaultConnectionTimeout = 30 * time.Second
	// MaxDataSize caps how many bytes of a DATA body are retained (500 KiB).
	MaxDataSize = 500 << 10
)

// smtpState records how far a session has progressed through an SMTP
// transaction. The constants are declared in protocol order and the session
// loop compares them with "<" to reject out-of-sequence commands, so the
// declaration order is significant.
type smtpState int

const (
	// stateNew: connection opened (or STARTTLS just completed); no HELO/EHLO yet.
	stateNew smtpState = iota
	// stateGreeted: HELO or EHLO accepted; MAIL may follow.
	stateGreeted
	// stateMail: MAIL FROM accepted; RCPT may follow.
	stateMail
	// stateRcpt: at least one RCPT TO accepted; DATA may follow.
	stateRcpt
	// stateData: a DATA body has been received.
	stateData
)

// Config holds the settings for an SMTP honeypot.
//
//   - ListenAddr: host part combined with each port (via utils.BuildAddress)
//     to form listen addresses; also passed to the self-signed certificate
//     generator.
//   - Ports: plaintext SMTP ports. STARTTLS is advertised on them when a TLS
//     certificate is available. Zero entries are skipped.
//   - SMTPSPorts: implicit-TLS (SMTPS) ports. They are not served when no TLS
//     certificate is available. Zero entries are skipped.
//   - Certificate: optional certificate for STARTTLS/SMTPS. When nil, Start
//     attempts to generate a self-signed one.
//   - CertConfig: parameters for that self-signed certificate generation.
type Config struct {
	ListenAddr  string
	Ports       []uint16
	SMTPSPorts  []uint16
	Certificate *tls.Certificate
	CertConfig  tlscert.CertConfig
}

// smtpHoneypot is the SMTP implementation returned by New as a
// honeypot.Honeypot.
//
// config is the configuration given to New, logger is assigned by Start, and
// tlsConfig is the server-side TLS configuration used for STARTTLS and SMTPS.
// When tlsConfig is nil, STARTTLS is not advertised and SMTPS listeners are
// not started.
type smtpHoneypot struct {
	config    Config
	logger    *slog.Logger
	tlsConfig *tls.Config
}

// smtpSession is the per-connection state of one SMTP conversation.
//
// remoteHost and remotePort identify the client; dstPort is the local port it
// connected to. state tracks protocol progress. commands is the log of lines
// received (the DATA body is not included). mailFrom, rcptTo and data hold what
// the client supplied via MAIL, RCPT and DATA (data is capped at MaxDataSize).
// tlsEnabled reports whether the connection is encrypted, and tlsType says
// how: "smtps" for implicit TLS or "starttls" after a STARTTLS upgrade.
type smtpSession struct {
	remoteHost string
	remotePort uint16
	dstPort    uint16
	state      smtpState
	commands   []string
	mailFrom   string
	rcptTo     []string
	data       []byte
	tlsEnabled bool
	tlsType    string
}

// New returns an SMTP honeypot configured by cfg. When cfg.Certificate is set,
// it becomes the server certificate for STARTTLS and SMTPS with TLS 1.2 as the
// minimum version. Otherwise no TLS configuration is built here, and Start may
// generate a self-signed certificate. As a side effect, New registers
// "auth_mechanism" as a top-N field for this honeypot type.
func New(cfg Config) honeypot.Honeypot {
	logger.RegisterTopNField(string(HoneypotType), "auth_mechanism")

	var serverTLS *tls.Config
	if cert := cfg.Certificate; cert != nil {
		serverTLS = &tls.Config{
			MinVersion:   tls.VersionTLS12,
			Certificates: []tls.Certificate{*cert},
		}
	}
	return &smtpHoneypot{config: cfg, tlsConfig: serverTLS}
}

// Name returns the honeypot type identifier (HoneypotType).
func (h *smtpHoneypot) Name() types.HoneypotType {
	return HoneypotType
}

// Label returns the human-readable label of this honeypot, "SMTP".
func (h *smtpHoneypot) Label() string { return HoneypotLabel }

// Start serves the SMTP honeypot until ctx is cancelled, and blocks until every
// listener has stopped.
//
// Plaintext SMTP (offering STARTTLS when a TLS certificate is available) runs
// on each port in Config.Ports. Implicit-TLS SMTPS runs on each port in
// Config.SMTPSPorts. Zero ports are skipped. If no certificate was configured
// and at least one port is set, a self-signed certificate is generated.
// Start always returns nil; certificate and listener failures are logged.
func (h *smtpHoneypot) Start(ctx context.Context, l *slog.Logger) error {
	h.logger = l
	if h.tlsConfig == nil && (len(h.config.Ports) > 0 || len(h.config.SMTPSPorts) > 0) {
		cert, err := tlscert.GenerateSelfSignedCert(h.config.ListenAddr, h.config.CertConfig, h.logger)
		if err != nil {
			// Without a certificate, STARTTLS is not advertised and SMTPS ports
			// cannot be served; record why instead of degrading silently.
			logger.LogError(h.logger, HoneypotType, "tls_cert_generation_failed", err, []any{"addr", h.config.ListenAddr})
		} else {
			h.tlsConfig = &tls.Config{
				Certificates: []tls.Certificate{*cert},
				MinVersion:   tls.VersionTLS12,
			}
		}
	}

	var wg sync.WaitGroup
	startServer := func(port uint16, useTLS bool) {
		if port == 0 {
			return
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			h.listenAndServe(ctx, port, useTLS)
		}()
	}

	for _, port := range h.config.Ports {
		startServer(port, false)
	}
	for _, port := range h.config.SMTPSPorts {
		startServer(port, true)
	}
	wg.Wait()
	return nil
}

// listenAndServe accepts SMTP connections on port until ctx is cancelled and
// handles each one in its own goroutine.
//
// When useTLS is set, the listener performs implicit TLS (SMTPS). That needs
// h.tlsConfig; if it is nil, the port is skipped and an error is logged. A
// listen failure is logged and ends only this listener.
func (h *smtpHoneypot) listenAndServe(ctx context.Context, port uint16, useTLS bool) {
	addr := utils.BuildAddress(h.config.ListenAddr, port)
	var ln net.Listener
	var err error

	if useTLS {
		if h.tlsConfig == nil {
			logger.LogError(h.logger, HoneypotType, "listen_failed",
				fmt.Errorf("no TLS certificate available for SMTPS"), []any{"addr", addr})
			return
		}
		ln, err = tls.Listen("tcp", addr, h.tlsConfig)
	} else {
		ln, err = net.Listen("tcp", addr)
	}

	if err != nil {
		logger.LogError(h.logger, HoneypotType, "listen_failed", err, []any{"addr", addr})
		return
	}
	defer ln.Close()

	// Closing the listener is what unblocks Accept when ctx is cancelled.
	go func() {
		<-ctx.Done()
		ln.Close()
	}()

	logger.LogInfo(h.logger, HoneypotType, "honeypot listening", []any{
		"port", port,
		"tls", useTLS,
	})

	// backoff is the delay before retrying after a failed Accept; it is reset
	// after every successful Accept.
	var backoff time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			// Persistent Accept errors (e.g. running out of file descriptors)
			// would otherwise spin this loop at full CPU. Back off
			// exponentially, capped at one second, as net/http.Server does.
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff = min(2*backoff, time.Second)
			}
			select {
			case <-ctx.Done():
				return
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0
		go h.handleSession(conn, useTLS)
	}
}

// handleSession runs one SMTP conversation on conn. It ends when the client
// sends QUIT, a read fails, a STARTTLS handshake fails, or the session
// deadline expires. implicitTLS reports whether conn came from an SMTPS
// (implicit TLS) listener.
//
// Each received line is recorded, and the session is logged once on exit. A
// TLS record (first byte 0x16) arriving on a plaintext connection is logged as
// a TLS handshake attempt, and the connection is then dropped.
func (h *smtpHoneypot) handleSession(conn net.Conn, implicitTLS bool) {
	defer conn.Close()
	// One absolute deadline bounds the whole session, including any TLS
	// handshake and the DATA phase.
	conn.SetDeadline(time.Now().Add(DefaultConnectionTimeout))

	rh, rp := utils.SplitAddr(conn.RemoteAddr().String(), h.logger)
	_, dp := utils.SplitAddr(conn.LocalAddr().String(), h.logger)

	session := &smtpSession{
		remoteHost: rh,
		remotePort: rp,
		dstPort:    dp,
		state:      stateNew,
		tlsEnabled: implicitTLS,
	}
	if implicitTLS {
		session.tlsType = "smtps"
	}

	reader := bufio.NewReader(conn)
	writer := bufio.NewWriter(conn)

	// SMTP is server-speaks-first, so the banner MUST be sent before anything
	// is read. On an implicit-TLS connection this first write also drives the
	// TLS handshake, so the banner only goes out once TLS is established.
	h.writeLine(writer, ServerBanner)

sessionLoop:
	for {
		// Detect a TLS handshake sent to a plaintext port.
		if !session.tlsEnabled {
			peek, _ := reader.Peek(1)
			if len(peek) > 0 && peek[0] == 0x16 {
				h.logTLSHandshake(session)
				return
			}
		}

		line, err := reader.ReadString('\n')
		if err != nil {
			break sessionLoop
		}

		line = strings.TrimRight(line, "\r\n")
		session.commands = append(session.commands, line)

		parts := strings.Fields(line)
		if len(parts) == 0 {
			continue
		}

		cmd := strings.ToUpper(parts[0])
		switch cmd {
		case "HELO":
			session.state = stateGreeted
			h.writeLine(writer, "250 localhost\r\n")

		case "EHLO":
			session.state = stateGreeted
			writer.WriteString("250-localhost\r\n")
			if h.tlsConfig != nil && !session.tlsEnabled {
				writer.WriteString("250-STARTTLS\r\n")
			}
			h.writeLine(writer, "250 AUTH LOGIN PLAIN\r\n")

		case "STARTTLS":
			if session.tlsEnabled || h.tlsConfig == nil {
				h.writeLine(writer, "454 TLS not available\r\n")
				continue
			}
			h.writeLine(writer, "220 Ready to start TLS\r\n")
			tlsConn := tls.Server(conn, h.tlsConfig)
			if err := tlsConn.Handshake(); err != nil {
				// A bare `break` here would only leave the switch and keep
				// reading from a connection stuck mid-handshake; end the session.
				break sessionLoop
			}
			conn = tlsConn
			reader = bufio.NewReader(conn)
			writer = bufio.NewWriter(conn)
			session.tlsEnabled = true
			session.tlsType = "starttls"
			// RFC 3207 resets the protocol state after the upgrade; the
			// client has to greet again.
			session.state = stateNew

		case "AUTH":
			h.handleAUTH(reader, writer, session, parts)

		case "MAIL":
			if session.state < stateGreeted {
				h.writeLine(writer, "503 Bad sequence of commands\r\n")
				continue
			}
			session.mailFrom = h.extractEmail(line)
			session.state = stateMail
			h.writeLine(writer, "250 OK\r\n")

		case "RCPT":
			if session.state < stateMail {
				h.writeLine(writer, "503 Bad sequence of commands\r\n")
				continue
			}
			session.rcptTo = append(session.rcptTo, h.extractEmail(line))
			session.state = stateRcpt
			h.writeLine(writer, "250 OK\r\n")

		case "DATA":
			if session.state != stateRcpt {
				h.writeLine(writer, "503 Bad sequence of commands\r\n")
				continue
			}
			h.handleDATA(reader, writer, session)

		case "RSET":
			// RSET aborts the mail transaction but is not a greeting: a client
			// that never sent HELO/EHLO must still do so before MAIL.
			if session.state > stateGreeted {
				session.state = stateGreeted
			}
			session.mailFrom, session.rcptTo, session.data = "", nil, nil
			h.writeLine(writer, "250 OK\r\n")

		case "NOOP":
			h.writeLine(writer, "250 OK\r\n")

		case "QUIT":
			h.writeLine(writer, "221 Bye\r\n")
			h.logSession(session)
			return

		default:
			h.writeLine(writer, "500 Command unrecognized\r\n")
		}
	}
	h.logSession(session)
}

// writeLine writes s to w and flushes at once, so the reply reaches the client
// before the server reads again. s must already end with "\r\n". Write and
// flush errors are intentionally discarded.
func (h *smtpHoneypot) writeLine(w *bufio.Writer, s string) {
	_, _ = w.WriteString(s)
	_ = w.Flush()
}

// handleAUTH plays the server side of an SMTP AUTH exchange (RFC 4954) for the
// PLAIN and LOGIN mechanisms. It collects the credentials the client offers,
// always replies 535, and logs the attempt. A missing mechanism gets 501. Any
// other mechanism gets 504 and is not logged.
//
// parts is the AUTH command line split on whitespace: parts[1] is the
// mechanism, and parts[2], when present, is the base64 initial response.
// Continuation lines read from r are appended to s.commands.
func (h *smtpHoneypot) handleAUTH(r *bufio.Reader, w *bufio.Writer, s *smtpSession, parts []string) {
	if len(parts) < 2 {
		h.writeLine(w, "501 Syntax error\r\n")
		return
	}

	mech := strings.ToUpper(parts[1])
	var initial string
	if len(parts) >= 3 {
		initial = parts[2]
	}
	var user, pass string

	switch mech {
	case "PLAIN":
		encoded := initial
		if encoded == "" {
			h.writeLine(w, "334 \r\n")
			encoded = h.readAuthResponse(r, s)
		}
		user, pass = h.decodeAuthPlain(encoded)

	case "LOGIN":
		if initial != "" {
			// "AUTH LOGIN <base64 username>": the username arrived as an
			// initial response. Prompting for it again would cause the
			// client's password to be recorded as the username.
			user = decodeBase64Text(initial)
		} else {
			h.writeLine(w, "334 VXNlcm5hbWU6\r\n") // Username:
			user = decodeBase64Text(h.readAuthResponse(r, s))
		}
		h.writeLine(w, "334 UGFzc3dvcmQ6\r\n") // Password:
		pass = decodeBase64Text(h.readAuthResponse(r, s))

	default:
		h.writeLine(w, "504 Unrecognized authentication type\r\n")
		return
	}

	h.writeLine(w, "535 Authentication failed\r\n")
	h.logAuthAttempt(s, mech, user, pass)
}

// readAuthResponse reads one SASL continuation line from r, records it
// (without its line terminator) in s.commands, and returns it with surrounding
// whitespace trimmed. It returns "" when the line cannot be read.
func (h *smtpHoneypot) readAuthResponse(r *bufio.Reader, s *smtpSession) string {
	line, err := r.ReadString('\n')
	if err != nil {
		return ""
	}
	s.commands = append(s.commands, strings.TrimRight(line, "\r\n"))
	return strings.TrimSpace(line)
}

// decodeBase64Text decodes standard base64 into a string. Malformed input
// yields whatever prefix decoded successfully, possibly "".
func decodeBase64Text(encoded string) string {
	decoded, _ := base64.StdEncoding.DecodeString(encoded)
	return string(decoded)
}

// decodeAuthPlain unpacks a base64-encoded SASL PLAIN message
// ("[authzid] NUL authcid NUL passwd", RFC 4616) and returns the
// authentication identity and the password. It returns two empty strings if
// the input is not valid standard base64, or does not contain exactly three
// NUL-separated fields.
func (h *smtpHoneypot) decodeAuthPlain(encoded string) (string, string) {
	raw, decodeErr := base64.StdEncoding.DecodeString(encoded)
	if decodeErr != nil {
		return "", ""
	}
	fields := bytes.Split(raw, []byte{0})
	if len(fields) == 3 {
		return string(fields[1]), string(fields[2])
	}
	return "", ""
}

// handleDATA collects a message body after an accepted DATA command.
//
// It replies 354 and reads lines until one of two things happens: a line
// consisting solely of "." (terminated by CRLF or a bare LF) arrives, or a read
// fails. What was received so far is kept on a read failure. It then stores
// the body in s.data, moves the session to stateData, and replies 250.
//
// Transparency dots are removed from the start of lines (RFC 5321 §4.5.2). At
// most MaxDataSize bytes are kept. Once a line would exceed that limit, it and
// every following line are dropped, so s.data is always a contiguous prefix of
// the message.
func (h *smtpHoneypot) handleDATA(r *bufio.Reader, w *bufio.Writer, s *smtpSession) {
	h.writeLine(w, "354 End data with <CRLF>.<CRLF>\r\n")
	var buf bytes.Buffer
	truncated := false
	for {
		l, err := r.ReadBytes('\n')
		if err != nil {
			break
		}
		// Only a lone "." ends the data; " ." is ordinary content.
		if string(bytes.TrimRight(l, "\r\n")) == "." {
			break
		}
		// Undo dot-stuffing: the client doubled any leading ".".
		// (l always holds at least "\n" here, so l[0] is in range.)
		if l[0] == '.' {
			l = l[1:]
		}
		if truncated {
			continue
		}
		if buf.Len()+len(l) > MaxDataSize {
			truncated = true
			continue
		}
		buf.Write(l)
	}
	s.data = buf.Bytes()
	s.state = stateData
	h.writeLine(w, "250 OK\r\n")
}

// logAuthAttempt emits an EventAuthAttempt event with the SASL mechanism and
// the captured username and password. When the session is encrypted, the event
// also carries the TLS mode ("smtps" or "starttls").
func (h *smtpHoneypot) logAuthAttempt(s *smtpSession, mech, user, pass string) {
	fields := make(map[string]interface{}, 4)
	fields["auth_mechanism"] = mech
	fields["username"] = user
	fields["password"] = pass
	if s.tlsEnabled {
		fields["tls"] = s.tlsType
	}

	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventAuthAttempt,
		RemoteAddr: s.remoteHost,
		RemotePort: s.remotePort,
		DstPort:    s.dstPort,
		Fields:     fields,
	}
	logger.LogEvent(h.logger, event)
}

// logSession emits one EventRequest event summarising the session.
//
// The event always contains the recorded command lines. It also contains,
// when present: the TLS mode, the MAIL FROM address, the RCPT TO addresses,
// and the DATA body together with its length in bytes. A session in which the
// client sent nothing is not logged.
func (h *smtpHoneypot) logSession(s *smtpSession) {
	if len(s.commands) == 0 {
		return
	}

	fields := map[string]interface{}{"commands": s.commands}
	if s.tlsEnabled {
		fields["tls"] = s.tlsType
	}
	if from := s.mailFrom; from != "" {
		fields["from"] = from
	}
	if rcpts := s.rcptTo; len(rcpts) > 0 {
		fields["to"] = rcpts
	}
	if size := len(s.data); size > 0 {
		fields["data_size"] = size
		fields["data"] = string(s.data)
	}

	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventRequest,
		RemoteAddr: s.remoteHost,
		RemotePort: s.remotePort,
		DstPort:    s.dstPort,
		Fields:     fields,
	}
	logger.LogEvent(h.logger, event)
}

// logTLSHandshake emits an EventTLSHandshake event for a client that began a
// TLS handshake on a plaintext SMTP port.
func (h *smtpHoneypot) logTLSHandshake(s *smtpSession) {
	const note = "TLS handshake attempt on non-TLS port"
	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventTLSHandshake,
		RemoteAddr: s.remoteHost,
		RemotePort: s.remotePort,
		DstPort:    s.dstPort,
		Fields:     map[string]interface{}{"message": note},
	}
	logger.LogEvent(h.logger, event)
}

// emailRegex captures the first non-empty "<...>" path in a MAIL or RCPT line.
var emailRegex = regexp.MustCompile(`<([^>]+)>`)

// extractEmail returns the address inside the first non-empty angle-bracketed
// path on line, or "" if there is none (including the null path "<>").
func (h *smtpHoneypot) extractEmail(line string) string {
	match := emailRegex.FindStringSubmatch(line)
	if match == nil {
		return ""
	}
	return match[1]
}

// GetScores scores each remote address that has SMTP events with a "to" field
// (i.e. the client supplied recipients) within the given interval. Every
// matching event adds 100 points, and each score carries TagAuthAttempt. If
// the query, a row scan, or row iteration fails, an empty ScoreMap is returned.
//
// interval is spliced verbatim into the SQL text after INTERVAL (NOTE(review):
// presumably a form like "1 DAY"; confirm with callers). It must therefore
// come from trusted configuration, never from external input.
func (h *smtpHoneypot) GetScores(db *database.Database, interval string) honeypot.ScoreMap {
	rows, err := db.DB.Query(fmt.Sprintf(`
	SELECT remote_addr, COUNT(*) as auth_count
	FROM honeypot_events
	WHERE type = 'smtp'
	AND fields.to IS NOT NULL
	AND time >= now() - INTERVAL %s
	GROUP BY remote_addr
	ORDER BY auth_count DESC
	`, interval))
	if err != nil {
		return honeypot.ScoreMap{}
	}
	defer rows.Close()

	scores := honeypot.ScoreMap{}
	for rows.Next() {
		var ip string
		var authCount uint
		err := rows.Scan(&ip, &authCount)
		if err != nil {
			return honeypot.ScoreMap{}
		}
		scores[ip] = honeypot.Score{Score: 100 * authCount, Tags: []types.Tag{types.TagAuthAttempt}}
	}
	// rows.Next returning false may mean an error mid-iteration; do not hand
	// back a silently partial result.
	if err := rows.Err(); err != nil {
		return honeypot.ScoreMap{}
	}
	return scores
}

// Ports returns every configured port: the plaintext/STARTTLS ports followed by
// the SMTPS ports. The result is a freshly allocated, non-nil slice, so callers
// may modify it without touching the configuration.
func (h *smtpHoneypot) Ports() []uint16 {
	all := make([]uint16, 0, len(h.config.Ports)+len(h.config.SMTPSPorts))
	all = append(all, h.config.Ports...)
	return append(all, h.config.SMTPSPorts...)
}