internal/honeypot/rdp/rdp.go

package rdp

import (
	"bufio"
	"context"
	"crypto/tls"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"strings"
	"sync"
	"time"

	"honeypot/internal/database"
	"honeypot/internal/honeypot"
	"honeypot/internal/logger"
	"honeypot/internal/types"
	"honeypot/internal/utils"
)

const (
	// HoneypotType identifies this honeypot in logged events and is what
	// Name returns.
	HoneypotType             = types.HoneypotTypeRDP
	// HoneypotLabel is the human-readable name returned by Label.
	HoneypotLabel            = "RDP"
	// DefaultConnectionTimeout is the absolute deadline applied to each
	// accepted connection (set once, covering the whole session).
	DefaultConnectionTimeout = 3 * time.Minute
	// RDPPort is the conventional RDP TCP port. The listeners themselves
	// bind to Config.Ports, not to this constant.
	RDPPort                  = 3389

	// TPKT constants
	// TPKT_VERSION is the first byte of every TPKT frame header.
	TPKT_VERSION = 3

	// X.224 constants
	// TPDU codes found in the second byte of the X.224 header.
	X224_TPDU_CONNECTION_REQUEST = 0xE0
	X224_TPDU_CONNECTION_CONFIRM = 0xD0

	// RDP Negotiation constants
	// Structure types of the negotiation blocks carried in the X.224
	// Connection Request / Confirm.
	TYPE_RDP_NEG_REQ = 0x01
	TYPE_RDP_NEG_RSP = 0x02
	TYPE_RDP_NEG_ERR = 0x03

	// Security protocol bits used in requestedProtocols / selectedProtocol.
	// PROTOCOL_RDP (0) means standard RDP security with no TLS.
	PROTOCOL_RDP       = 0x00000000
	PROTOCOL_SSL       = 0x00000001
	PROTOCOL_HYBRID    = 0x00000002
	PROTOCOL_HYBRID_EX = 0x00000008

	// MCS constants
	// MCS PDU identifiers; not referenced by the code shown here (the
	// honeypot stops before the MCS exchange).
	MCS_CONNECT_INITIAL      = 0x7f
	MCS_CONNECT_RESPONSE     = 0x70
	MCS_ERECT_DOMAIN         = 0x04
	MCS_ATTACH_USER_REQUEST  = 0x28
	MCS_ATTACH_USER_CONFIRM  = 0x2c
	MCS_CHANNEL_JOIN_REQUEST = 0x38
	MCS_CHANNEL_JOIN_CONFIRM = 0x3c

	// RDP Data PDU types
	// Not referenced by the code shown here.
	PDUTYPE_CLIENT_INFO = 0x02
)

// Config holds the configuration for the RDP honeypot.
type Config struct {
	// ListenAddr is combined with each entry of Ports (via
	// utils.BuildAddress) to form the listen address.
	ListenAddr  string
	// Ports lists the TCP ports to listen on; zero entries are skipped by
	// Start.
	Ports       []uint16
	// Certificate, when non-nil, allows the honeypot to select TLS
	// (PROTOCOL_SSL) or NLA (PROTOCOL_HYBRID) security for clients that
	// request it. When nil, standard RDP security is always selected.
	Certificate *tls.Certificate
}

// rdpHoneypot implements the honeypot.Honeypot interface.
type rdpHoneypot struct {
	// config is the configuration passed to New.
	config Config
	// logger is assigned by Start; it is nil until Start runs.
	logger *slog.Logger
}

// New creates a new RDP honeypot instance from cfg.
//
// As a side effect it registers the "username" field of "rdp" events for
// top-N tracking in the logger package. The returned honeypot has no logger
// until Start is called.
func New(cfg Config) honeypot.Honeypot {
	logger.RegisterTopNField("rdp", "username")

	h := new(rdpHoneypot)
	h.config = cfg
	return h
}

// Name reports the honeypot type identifier for the RDP honeypot.
func (h *rdpHoneypot) Name() types.HoneypotType {
	var name types.HoneypotType = HoneypotType
	return name
}

// Label reports the human-readable label ("RDP") of this honeypot.
func (h *rdpHoneypot) Label() string {
	label := HoneypotLabel
	return label
}

// Start runs one RDP listener per non-zero configured port and blocks until
// all of them have stopped (normally because ctx was cancelled, or
// immediately for a port that failed to listen). It records l as the
// honeypot's logger and always returns nil.
func (h *rdpHoneypot) Start(ctx context.Context, l *slog.Logger) error {
	h.logger = l

	var servers sync.WaitGroup
	// launch starts a single listener goroutine tracked by servers; p is a
	// parameter so each goroutine sees its own port value.
	launch := func(p uint16) {
		servers.Add(1)
		go func() {
			defer servers.Done()
			h.startRDPServer(ctx, p)
		}()
	}

	for i := range h.config.Ports {
		if p := h.config.Ports[i]; p != 0 {
			launch(p)
		}
	}

	servers.Wait()
	logger.LogInfo(h.logger, HoneypotType, "honeypot shutdown complete", nil)
	return nil
}

// startRDPServer listens on port and serves connections until ctx is
// cancelled or the listener is closed. Each accepted connection is handled on
// its own goroutine.
//
// Failing Accept calls are retried with a capped exponential backoff (similar
// to net/http.Server). Conditions such as file-descriptor exhaustion therefore
// no longer turn the loop into a busy spin that floods the error log. A closed
// listener ends the loop instead of being retried forever.
func (h *rdpHoneypot) startRDPServer(ctx context.Context, port uint16) {
	listenAddr := utils.BuildAddress(h.config.ListenAddr, port)
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		logger.LogError(h.logger, HoneypotType, "listen_failed", err, []any{
			"addr", listenAddr,
		})
		return
	}
	defer listener.Close()

	logger.LogInfo(h.logger, HoneypotType, "honeypot listening", []any{
		"port", port,
	})

	// Close the listener on cancellation to unblock Accept. The stop channel
	// also lets this goroutine exit when the accept loop ends for any other
	// reason, so it never outlives this function.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			listener.Close()
		case <-stop:
		}
	}()

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			logger.LogError(h.logger, HoneypotType, "accept_failed", err, nil)

			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff = min(2*backoff, maxAcceptBackoff)
			}
			select {
			case <-ctx.Done():
				return
			case <-time.After(backoff):
			}
			continue
		}

		backoff = 0
		go h.handleConn(conn)
	}
}

// handleConn serves one accepted TCP connection.
//
// It classifies the first bytes (TLS record, non-TPKT data, or a TPKT
// frame). It then parses the X.224 Connection Request, extracting the
// mstshash cookie username and any RDP_NEG_REQ, and replies with a
// Connection Confirm carrying an RDP_NEG_RSP. Finally it continues on the
// TLS/NLA path or the standard RDP path, depending on the selected protocol.
//
// The whole session is bounded by DefaultConnectionTimeout. If no branch
// records an event, a deferred fallback logs a tcp_packet event so that bare
// connections remain visible. Every branch that logs sets eventLogged so
// the fallback is not emitted as a duplicate. GetScores counts tcp_packet
// events, so duplicates would inflate scores.
func (h *rdpHoneypot) handleConn(conn net.Conn) {
	defer conn.Close()
	// One absolute deadline for the session; it also bounds the TLS phase
	// because tls.Server below wraps this same conn.
	conn.SetDeadline(time.Now().Add(DefaultConnectionTimeout))

	remoteHost, remotePort := utils.SplitAddr(conn.RemoteAddr().String(), h.logger)
	_, dstPort := utils.SplitAddr(conn.LocalAddr().String(), h.logger)

	eventLogged := false
	defer func() {
		if !eventLogged {
			h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
				"message": "connection closed without rdp events",
			})
		}
	}()

	reader := bufio.NewReader(conn)

	// Peek at the first 5 bytes to determine the protocol. io.EOF is
	// tolerated so that short probes are still classified below.
	peeked, err := reader.Peek(5)
	if err != nil && err != io.EOF {
		return
	}

	if len(peeked) > 0 {
		// Detect TLS Handshake (0x16 ...)
		if peeked[0] == 0x16 {
			eventLogged = true
			h.logEvent(remoteHost, remotePort, dstPort, types.EventTLSHandshake, map[string]interface{}{
				"message": "TLS handshake detected on RDP port",
			})
			return
		}

		// Check if it looks like TPKT (0x03)
		if peeked[0] != TPKT_VERSION {
			eventLogged = true
			h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
				"message": "non-rdp protocol detected",
				"payload": fmt.Sprintf("%x", peeked),
			})
			return
		}
	}

	// 1. Read TPKT Header (4 bytes)
	tpktHeader := make([]byte, 4)
	if _, err := io.ReadFull(reader, tpktHeader); err != nil {
		eventLogged = true
		h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
			"message": "incomplete tpkt header",
			"error":   err.Error(),
		})
		return
	}

	if tpktHeader[0] != TPKT_VERSION {
		return
	}

	// The TPKT length includes its own 4-byte header.
	length := binary.BigEndian.Uint16(tpktHeader[2:])
	if length < 4 {
		eventLogged = true
		h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
			"message": "invalid tpkt length",
			"length":  length,
		})
		return
	}

	// 2. Read X.224 Connection Request
	payloadLen := int(length) - 4
	payload := make([]byte, payloadLen)
	if _, err := io.ReadFull(reader, payload); err != nil {
		eventLogged = true
		h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
			"message": "incomplete tpkt payload",
			"error":   err.Error(),
			"header":  fmt.Sprintf("%x", tpktHeader),
		})
		return
	}

	if len(payload) < 3 {
		eventLogged = true
		h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
			"message": "tpkt payload too small for x224",
			"payload": fmt.Sprintf("%x", payload),
		})
		return
	}

	// payload[0] is the X.224 length indicator (excluding itself),
	// payload[1] is the TPDU code.
	pduType := payload[1]
	if pduType != X224_TPDU_CONNECTION_REQUEST {
		eventLogged = true
		h.logEvent(remoteHost, remotePort, dstPort, types.EventTCPPacket, map[string]interface{}{
			"message":  "unexpected x224 pdu type",
			"pdu_type": fmt.Sprintf("0x%02x", pduType),
			"payload":  fmt.Sprintf("%x", payload),
		})
		return
	}

	// Username from the "Cookie: mstshash=USERNAME\r\n" token, if any.
	cookieStr := parseMstshashCookie(payload)

	// Optional RDP Negotiation Request that follows the cookie.
	requestedProtocols, hasNegReq := scanRDPNegReq(payload)

	// 3. Send Connection Confirm (CC). TLS-based security can only be
	// offered when a certificate is configured; HYBRID_EX requests are
	// answered with plain HYBRID.
	selectedProtocol := uint32(PROTOCOL_RDP)
	if hasNegReq && h.config.Certificate != nil {
		if requestedProtocols&PROTOCOL_HYBRID != 0 || requestedProtocols&PROTOCOL_HYBRID_EX != 0 {
			selectedProtocol = PROTOCOL_HYBRID
		} else if requestedProtocols&PROTOCOL_SSL != 0 {
			selectedProtocol = PROTOCOL_SSL
		}
	}

	// TPKT + X.224 CC + RDP Negotiation Response
	resp := []byte{
		0x03, 0x00, 0x00, 0x13, // TPKT: v3, reserved, length 19
		0x0e,                         // X.224 LI: 14
		X224_TPDU_CONNECTION_CONFIRM, // X.224: CC
		0x00, 0x00,                   // DST-REF
		0x12, 0x34,                   // SRC-REF
		0x00,                         // Class
		TYPE_RDP_NEG_RSP, 0x00, 0x08, 0x00, // RDP_NEG_RSP: type 2, flags 0, length 8
		0x00, 0x00, 0x00, 0x00, // selectedProtocol, filled in below
	}
	binary.LittleEndian.PutUint32(resp[15:], selectedProtocol)

	if _, err := conn.Write(resp); err != nil {
		return
	}

	// 4a. Standard RDP security: log the cookie username, then consume the
	// next PDU.
	if selectedProtocol != PROTOCOL_SSL && selectedProtocol != PROTOCOL_HYBRID {
		eventLogged = true
		h.logAuthAttempt(remoteHost, remotePort, dstPort, cookieStr, "", map[string]interface{}{
			"requested_protocols": requestedProtocols,
			"has_neg_req":         hasNegReq,
			"security_layer":      "rdp",
		})
		h.handleStandardRDP(conn, reader, remoteHost, remotePort, dstPort)
		return
	}

	// 4b. Upgrade to TLS.
	h.logger.Debug("upgrading to TLS", "protocol", selectedProtocol)
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{*h.config.Certificate},
		MinVersion:   tls.VersionTLS12,
	}
	tlsConn := tls.Server(conn, tlsConfig)
	if err := tlsConn.Handshake(); err != nil {
		h.logger.Debug("TLS handshake failed", "error", err)
		// The auth attempt below is this connection's event; without this
		// flag the fallback would log a second, spurious tcp_packet event.
		eventLogged = true
		h.logAuthAttempt(remoteHost, remotePort, dstPort, cookieStr, "", map[string]interface{}{
			"requested_protocols": requestedProtocols,
			"has_neg_req":         hasNegReq,
			"error":               "tls_handshake_failed",
		})
		return
	}
	h.logger.Debug("TLS handshake successful")

	// For PROTOCOL_SSL, handleTLSConn logs the cookie username as an auth
	// attempt before doing any further I/O, so the fallback would duplicate
	// it. For NLA the fallback still applies if no NTLM Authenticate arrives.
	if selectedProtocol == PROTOCOL_SSL {
		eventLogged = true
	}
	h.handleTLSConn(tlsConn, remoteHost, remotePort, dstPort, cookieStr, selectedProtocol)
}

// parseMstshashCookie returns USERNAME from a "Cookie: mstshash=USERNAME\r\n"
// token in an X.224 Connection Request payload. It returns "" when the cookie
// is absent or not CRLF-terminated.
func parseMstshashCookie(payload []byte) string {
	const prefix = "Cookie: mstshash="
	_, rest, found := strings.Cut(string(payload), prefix)
	if !found {
		return ""
	}
	user, _, terminated := strings.Cut(rest, "\r\n")
	if !terminated {
		return ""
	}
	return user
}

// scanRDPNegReq looks for an RDP_NEG_REQ structure in an X.224 Connection
// Request payload and returns its requestedProtocols field. The structure is
// type 0x01, a flags byte, length 0x0008 (little-endian), and
// requestedProtocols (uint32, little-endian).
//
// The scan starts after the 7-byte fixed X.224 CR header (LI, code, DST-REF,
// SRC-REF, class). It accepts any flags value, because clients may set flags
// such as CORRELATION_INFO_PRESENT, and requiring zero would miss their
// requests.
func scanRDPNegReq(payload []byte) (uint32, bool) {
	const x224FixedHeaderLen = 7
	for i := x224FixedHeaderLen; i+8 <= len(payload); i++ {
		if payload[i] == TYPE_RDP_NEG_REQ && payload[i+2] == 0x08 && payload[i+3] == 0x00 {
			return binary.LittleEndian.Uint32(payload[i+4 : i+8]), true
		}
	}
	return 0, false
}

// handleStandardRDP reads one more TPKT-framed PDU from reader and then
// returns, ending the session. It runs after a Connection Confirm that
// selected standard RDP security, and the PDU is expected to be the MCS
// Connect Initial. The honeypot does not implement the MCS/GCC exchange. In
// this file the caller has already logged the auth attempt before calling.
//
// conn, remoteHost, remotePort and dstPort are currently unused; they are kept
// for signature compatibility.
func (h *rdpHoneypot) handleStandardRDP(conn net.Conn, reader *bufio.Reader, remoteHost string, remotePort, dstPort uint16) {
	h.logger.Debug("entering handleStandardRDP")

	// 1. Read MCS Connect Initial
	tpktHeader := make([]byte, 4)
	if _, err := io.ReadFull(reader, tpktHeader); err != nil {
		return
	}
	length := binary.BigEndian.Uint16(tpktHeader[2:])
	// The TPKT length covers its own 4-byte header. Anything smaller is
	// malformed and would make the payload size negative, and make would
	// panic and take down the whole process.
	if length < 4 {
		h.logger.Debug("invalid TPKT length after connection confirm", "length", length)
		return
	}
	payload := make([]byte, int(length)-4)
	if _, err := io.ReadFull(reader, payload); err != nil {
		return
	}

	h.logger.Debug("received MCS Connect Initial, aborting")
}

// handleTLSConn continues a connection after a successful TLS handshake.
//
// For PROTOCOL_SSL it logs the cookie username as an auth attempt and
// returns. For PROTOCOL_HYBRID (NLA) it runs the CredSSP exchange far enough
// to capture NTLM credentials:
//   - TSRequests arrive either wrapped in a TPKT/X.224 data frame or as a
//     bare ASN.1 SEQUENCE. Replies use the same framing as the request.
//   - An NTLM Negotiate (type 1) is answered with a fixed Challenge (type 2).
//   - An NTLM Authenticate (type 3) is parsed and logged, and the function
//     returns.
//
// Messages that do not parse as TSRequests or lack a usable NTLMSSP token are
// skipped. The loop ends on any read error, including the connection
// deadline.
func (h *rdpHoneypot) handleTLSConn(conn net.Conn, remoteHost string, remotePort, dstPort uint16, username string, selectedProtocol uint32) {
	h.logger.Debug("entering handleTLSConn", "username", username)
	reader := bufio.NewReader(conn)

	// If the selected protocol is just TLS (PROTOCOL_SSL), we log the attempt
	// based on the cookie username and abort before the MCS handshake.
	if selectedProtocol == PROTOCOL_SSL {
		h.logAuthAttempt(remoteHost, remotePort, dstPort, username, "", map[string]interface{}{
			"security_layer": "tls",
		})
		return
	}

	for {
		// Peek at the first byte to determine if it's TPKT (0x03) or ASN.1 (0x30)
		peeked, err := reader.Peek(1)
		if err != nil {
			h.logger.Debug("failed to peek byte in TLS", "error", err)
			return
		}
		// Copy the byte out. The slice returned by Peek is only valid until the
		// next read, and the framing is consulted again after the frame has
		// been consumed, when choosing how to wrap the reply.
		framing := peeked[0]

		var payload []byte

		switch framing {
		case TPKT_VERSION:
			// 1. Read TPKT Header (4 bytes)
			tpktHeader := make([]byte, 4)
			if _, err := io.ReadFull(reader, tpktHeader); err != nil {
				h.logger.Debug("failed to read TPKT header in TLS", "error", err)
				return
			}

			length := binary.BigEndian.Uint16(tpktHeader[2:])
			h.logger.Debug("received TPKT in TLS", "length", length)
			if length < 7 { // 4 (TPKT) + 3 (X.224)
				return
			}

			// 2. Read X.224 Header (3 bytes)
			x224Header := make([]byte, 3)
			if _, err := io.ReadFull(reader, x224Header); err != nil {
				h.logger.Debug("failed to read X.224 header in TLS", "error", err)
				return
			}

			// 3. Read TSRequest Payload
			payload = make([]byte, int(length)-7)
			if _, err := io.ReadFull(reader, payload); err != nil {
				h.logger.Debug("failed to read TSRequest payload", "error", err)
				return
			}
		case 0x30:
			// Direct ASN.1 SEQUENCE (CredSSP TSRequest)
			payload, err = h.readASN1Object(reader)
			if err != nil {
				h.logger.Debug("failed to read ASN.1 object", "error", err)
				return
			}
		default:
			h.logger.Debug("unknown protocol byte in TLS", "byte", framing)
			return
		}

		ts, err := parseTSRequest(payload)
		if err != nil {
			h.logger.Debug("failed to parse TSRequest", "error", err, "payload_len", len(payload))
			// Not a TSRequest, could be MCS or something else
			continue
		}

		h.logger.Debug("parsed TSRequest", "nego_tokens", len(ts.NegoTokens))

		if len(ts.NegoTokens) == 0 {
			continue
		}

		token := ts.NegoTokens[0].NegoToken
		// An NTLMSSP message starts with an 8-byte signature followed by a
		// 4-byte message type. Shorter tokens cannot be dispatched, and
		// slicing token[8:12] would panic and crash the process.
		if len(token) < 12 || string(token[:8]) != NTLMSSP_SIGNATURE {
			h.logger.Debug("invalid NTLMSSP signature", "token_len", len(token))
			continue
		}

		msgType := binary.LittleEndian.Uint32(token[8:12])
		h.logger.Debug("received NTLMSSP message", "type", msgType)

		switch msgType {
		case NTLM_TYPE1: // Negotiate
			clientFlags := uint32(0)
			if len(token) >= 16 {
				clientFlags = binary.LittleEndian.Uint32(token[12:16])
			}
			h.logger.Debug("handling NTLM Negotiate", "client_flags", fmt.Sprintf("0x%08x", clientFlags))

			// Send Type 2 Challenge with a dummy target name and info
			targetName := "WORKGROUP"
			targetNameBytes := []byte{0x57, 0x00, 0x4f, 0x00, 0x52, 0x00, 0x4b, 0x00, 0x47, 0x00, 0x52, 0x00, 0x4f, 0x00, 0x55, 0x00, 0x50, 0x00} // "WORKGROUP" in UTF-16LE

			// Target Info (AV_PAIR list): one pair with AvId 3 carrying
			// "WORKGROUP", terminated by MsvAvEOL.
			// NOTE(review): in MS-NLMP, AvId 3 is MsvAvDnsComputerName
			// (MsvAvNbComputerName is 1). The bytes are left unchanged to keep
			// the wire format identical — confirm which was intended.
			targetInfoBytes := []byte{
				0x03, 0x00, 0x12, 0x00, // AvId: 3, AvLen: 18
			}
			targetInfoBytes = append(targetInfoBytes, targetNameBytes...)
			targetInfoBytes = append(targetInfoBytes, 0x00, 0x00, 0x00, 0x00) // AvId: MsvAvEOL (0), AvLen: 0

			// 56-byte fixed CHALLENGE_MESSAGE header, then TargetName, then TargetInfo.
			challenge := make([]byte, 56+len(targetNameBytes)+len(targetInfoBytes))
			copy(challenge[0:8], NTLMSSP_SIGNATURE)
			binary.LittleEndian.PutUint32(challenge[8:12], 2) // Type 2

			// TargetNameFields (Len, MaxLen, Offset)
			targetNameLen := len(targetNameBytes)
			binary.LittleEndian.PutUint16(challenge[12:14], uint16(targetNameLen))
			binary.LittleEndian.PutUint16(challenge[14:16], uint16(targetNameLen))
			binary.LittleEndian.PutUint32(challenge[16:20], 56)

			// NegotiateFlags 0xa0088205 = UNICODE (0x1) | REQUEST_TARGET (0x4) |
			// NTLM (0x200) | ALWAYS_SIGN (0x8000) | EXTENDED_SESSIONSECURITY
			// (0x80000) | 128 (0x20000000) | 56 (0x80000000).
			// NOTE(review): NTLMSSP_NEGOTIATE_TARGET_INFO (0x00800000) is not
			// set even though TargetInfoFields are populated below — confirm
			// that target clients accept this.
			binary.LittleEndian.PutUint32(challenge[20:24], 0xa0088205)

			// Server challenge: a fixed value, identical for every session.
			copy(challenge[24:32], []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef})

			// Reserved (8 bytes, left zero)

			// TargetInfoFields (Len, MaxLen, Offset)
			targetInfoLen := len(targetInfoBytes)
			binary.LittleEndian.PutUint16(challenge[40:42], uint16(targetInfoLen))
			binary.LittleEndian.PutUint16(challenge[42:44], uint16(targetInfoLen))
			binary.LittleEndian.PutUint32(challenge[44:48], uint32(56+targetNameLen))

			// Version: 10.0, build 0x4a61 (19041), NTLM revision 0x0f
			copy(challenge[48:56], []byte{0x0a, 0x00, 0x61, 0x4a, 0x00, 0x00, 0x00, 0x0f})

			// Append Payload
			copy(challenge[56:], targetNameBytes)
			copy(challenge[56+len(targetNameBytes):], targetInfoBytes)

			respTs, err := buildTSResponse(challenge)
			if err != nil {
				h.logger.Error("failed to build TSResponse", "error", err)
				return
			}

			// Reply with the same framing the client used.
			var resp []byte
			if framing == TPKT_VERSION {
				resp = make([]byte, 4+3+len(respTs))
				resp[0] = TPKT_VERSION
				respLen := len(resp)
				binary.BigEndian.PutUint16(resp[2:], uint16(respLen))
				resp[4] = 0x02 // X.224 LI
				resp[5] = 0xf0 // X.224 Data TPDU
				resp[6] = 0x80 // EOT
				copy(resp[7:], respTs)
			} else {
				// Send raw ASN.1 response
				resp = respTs
			}

			if _, err := conn.Write(resp); err != nil {
				h.logger.Error("failed to write NTLM challenge", "error", err)
				return
			}
			h.logger.Debug("sent NTLM challenge", "target", targetName)

		case NTLM_TYPE3: // Authenticate
			user, dom, work, hash := parseNTLMType3(token)
			h.logger.Debug("handling NTLM Authenticate", "user", user, "domain", dom)
			h.logAuthAttempt(remoteHost, remotePort, dstPort, user, "", map[string]interface{}{
				"domain":         dom,
				"workstation":    work,
				"ntlm_hash":      hash,
				"security_layer": "nla",
			})
			return // Done with this connection
		}
	}
}

// readASN1Object reads one definite-length ASN.1 (DER) object from reader and
// returns its complete encoding: tag, length octets and content. The input is
// untrusted, so the length encoding is validated before anything is
// allocated:
//   - The indefinite form (0x80) is rejected, since it is not valid DER.
//   - More than four length octets are rejected.
//   - Content longer than maxContentLen is rejected.
//
// Without these checks a long-form length could overflow int to a negative
// size (a makeslice panic) or request an enormous allocation. Read errors
// from reader are returned unchanged.
func (h *rdpHoneypot) readASN1Object(reader *bufio.Reader) ([]byte, error) {
	// Upper bound on content we are willing to buffer. CredSSP TSRequests are
	// far smaller; this only guards against attacker-declared sizes.
	const maxContentLen = 1 << 20

	// 1. Read Tag
	tag, err := reader.ReadByte()
	if err != nil {
		return nil, err
	}

	// 2. Read Length
	lenByte, err := reader.ReadByte()
	if err != nil {
		return nil, err
	}

	header := []byte{tag, lenByte}
	length := uint64(lenByte)
	if lenByte&0x80 != 0 {
		numLenBytes := int(lenByte & 0x7f)
		if numLenBytes == 0 || numLenBytes > 4 {
			return nil, fmt.Errorf("unsupported asn.1 length encoding 0x%02x", lenByte)
		}
		lengthBytes := make([]byte, numLenBytes)
		if _, err := io.ReadFull(reader, lengthBytes); err != nil {
			return nil, err
		}
		length = 0
		for _, b := range lengthBytes {
			length = length<<8 | uint64(b)
		}
		header = append(header, lengthBytes...)
	}
	if length > maxContentLen {
		return nil, fmt.Errorf("asn.1 object too large: %d bytes", length)
	}

	// 3. Read Content directly after the reconstructed header.
	res := make([]byte, len(header)+int(length))
	copy(res, header)
	if _, err := io.ReadFull(reader, res[len(header):]); err != nil {
		return nil, err
	}

	return res, nil
}

// logAuthAttempt records an auth_attempt event for the given peer. username
// is always included; password only when non-empty. A non-nil fields map is
// written to in place (the caller's map gains these keys). A nil map is
// replaced with a fresh one.
func (h *rdpHoneypot) logAuthAttempt(remoteHost string, remotePort uint16, dstPort uint16, username, password string, fields map[string]interface{}) {
	attrs := fields
	if attrs == nil {
		attrs = map[string]interface{}{}
	}

	attrs["username"] = username
	if len(password) > 0 {
		attrs["password"] = password
	}

	h.logEvent(remoteHost, remotePort, dstPort, types.EventAuthAttempt, attrs)
}

// logEvent emits a structured honeypot event of the given kind, tagged with
// this honeypot's type and the peer/destination addressing, through the
// shared logger package.
func (h *rdpHoneypot) logEvent(remoteHost string, remotePort uint16, dstPort uint16, event types.HoneypotEvent, fields map[string]interface{}) {
	entry := types.LogEvent{
		Type:   HoneypotType,
		Event:  event,
		Fields: fields,
	}
	entry.RemoteAddr = remoteHost
	entry.RemotePort = remotePort
	entry.DstPort = dstPort

	logger.LogEvent(h.logger, entry)
}

// GetScores returns, per remote address, a score of 100 for every RDP
// tcp_packet event recorded within interval. Any query, scan or iteration
// failure yields an empty map.
//
// NOTE(review): interval is interpolated directly into the SQL text, so it
// must come from trusted code and never from request data.
func (h *rdpHoneypot) GetScores(db *database.Database, interval string) honeypot.ScoreMap {
	// get score for tcp_packets on rdp type
	rows, err := db.DB.Query(fmt.Sprintf(`
	SELECT remote_addr, COUNT(*) as tcp_connection
	FROM honeypot_events
	WHERE type = 'rdp'
	AND event = 'tcp_packet'
	AND time >= now() - INTERVAL %s
	GROUP BY remote_addr`, interval))
	if err != nil {
		return honeypot.ScoreMap{}
	}
	defer rows.Close()

	scoreMap := honeypot.ScoreMap{}
	for rows.Next() {
		var remoteAddr string
		var tcpConnection uint
		if err := rows.Scan(&remoteAddr, &tcpConnection); err != nil {
			return honeypot.ScoreMap{}
		}
		scoreMap[remoteAddr] = honeypot.Score{Score: tcpConnection * 100, Tags: []types.Tag{types.TagAuthAttempt}}
	}
	// Next also returns false when iteration fails part-way. Without this
	// check a truncated result set would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return honeypot.ScoreMap{}
	}

	return scoreMap
}

// Ports reports the TCP ports from the configuration, including any zero
// entries that Start skips. The returned slice shares storage with the
// config.
func (h *rdpHoneypot) Ports() []uint16 {
	configured := h.config.Ports
	return configured
}