internal/honeypot/packetlogger/packetlogger.go

package packetlogger

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"time"

	"honeypot/internal/database"
	"honeypot/internal/honeypot"
	"honeypot/internal/logger"
	"honeypot/internal/types"
	"honeypot/internal/utils"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/pcap"
)

// Package-level settings for the packet logger honeypot.
//
//   - HoneypotType is the identifier returned by Name and attached to every
//     log call made by this honeypot.
//   - HoneypotLabel is the label returned by Label.
//   - DefaultSnapshotLen is the snapshot length passed to pcap.OpenLive.
//   - ShutdownTimeout is not referenced in the code visible here.
//     NOTE(review): presumably consumed by callers elsewhere — confirm or remove.
const (
	HoneypotType       = types.HoneypotTypePacketLogger
	HoneypotLabel      = "Packet"
	DefaultSnapshotLen = 1600
	ShutdownTimeout    = 1 * time.Second
)

// Config holds the configuration for the packet logger.
type Config struct {
	Interface     string // name of the network interface to capture on
	BpfExpression string // optional extra BPF expression ANDed with the built-in filter; empty disables it
}

// packetLoggerHoneypot implements the honeypot.Honeypot interface by
// sniffing packets on a network interface.
type packetLoggerHoneypot struct {
	config Config       // capture settings supplied to New
	logger *slog.Logger // assigned by Start; nil until Start has been called
}

// New builds a packet logger honeypot from cfg. The instance logger is left
// unset; it is assigned when Start is called.
func New(cfg Config) honeypot.Honeypot {
	hp := new(packetLoggerHoneypot)
	hp.config = cfg
	return hp
}

// Name reports the honeypot type identifier (HoneypotType) for this
// implementation.
func (h *packetLoggerHoneypot) Name() types.HoneypotType {
	return HoneypotType
}

// Label reports the short display label (HoneypotLabel) for this honeypot.
func (h *packetLoggerHoneypot) Label() string {
	return HoneypotLabel
}

// Start runs the packet logger and blocks until ctx is cancelled or the
// capture stops delivering packets.
//
// It stores l as the instance logger, checks that the configured interface
// exists and has an IPv4 address, opens a live capture, installs the BPF
// filter and then processes packets. Any setup failure is logged and
// returned; a normal shutdown returns nil.
func (h *packetLoggerHoneypot) Start(ctx context.Context, l *slog.Logger) error {
	h.logger = l

	if err := h.validateInterface(); err != nil {
		return err
	}

	ifaceIP, err := h.getInterfaceIP()
	if err != nil {
		return err
	}

	handle, err := h.openPacketCapture()
	if err != nil {
		return err
	}
	defer handle.Close()

	bpfFilter := h.buildBPFFilter(ifaceIP)
	if err := handle.SetBPFFilter(bpfFilter); err != nil {
		logger.LogError(h.logger, HoneypotType, "set_bpf_filter_failed", err, []any{
			"bpf_expression", bpfFilter,
		})
		return err
	}

	logger.LogInfo(h.logger, HoneypotType, "honeypot listening", []any{
		"interface", h.config.Interface,
		"bpf_expression", bpfFilter,
	})

	// The shutdown goroutine blocks on ctx.Done(). Derive a context that is
	// also cancelled when Start returns, so that goroutine cannot outlive
	// this call if the capture ends before the caller cancels ctx.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	h.setupGracefulShutdown(ctx, handle)
	h.processPackets(ctx, handle)

	logger.LogInfo(h.logger, HoneypotType, "honeypot shutdown complete", nil)
	return nil
}

// validateInterface confirms that the configured interface name appears in
// the device list reported by pcap. Lookup failures and a missing interface
// are logged and returned as errors.
func (h *packetLoggerHoneypot) validateInterface() error {
	wanted := h.config.Interface

	available, err := pcap.FindAllDevs()
	if err != nil {
		logger.LogError(h.logger, HoneypotType, "find_devices_failed", err, nil)
		return err
	}

	for i := range available {
		if available[i].Name == wanted {
			return nil
		}
	}

	notFound := fmt.Errorf("interface %s not found", wanted)
	logger.LogError(h.logger, HoneypotType, "interface_not_found", notFound, []any{
		"interface", wanted,
	})
	return notFound
}

// getInterfaceIP returns the first IPv4 address assigned to the configured
// interface. It logs and returns an error when the interface cannot be
// resolved, its addresses cannot be listed, or none of them is IPv4.
func (h *packetLoggerHoneypot) getInterfaceIP() (net.IP, error) {
	name := h.config.Interface

	nic, err := net.InterfaceByName(name)
	if err != nil {
		logger.LogError(h.logger, HoneypotType, "get_interface_failed", err, []any{
			"interface", name,
		})
		return nil, err
	}

	addresses, err := nic.Addrs()
	if err != nil {
		logger.LogError(h.logger, HoneypotType, "get_interface_addrs_failed", err, []any{
			"interface", nic.Name,
		})
		return nil, err
	}

	for _, candidate := range addresses {
		ipNet, isIPNet := candidate.(*net.IPNet)
		if !isIPNet || ipNet.IP.To4() == nil {
			// Skip non-IPNet address kinds and IPv6 addresses.
			continue
		}
		return ipNet.IP, nil
	}

	noIPv4 := fmt.Errorf("no IPv4 address found on interface %s", nic.Name)
	logger.LogError(h.logger, HoneypotType, "get_iface_ip_failed", noIPv4, []any{
		"interface", nic.Name,
	})
	return nil, noIPv4
}

// openPacketCapture opens a promiscuous live capture on the configured
// interface using DefaultSnapshotLen and a blocking read timeout. A failure
// is logged and returned.
func (h *packetLoggerHoneypot) openPacketCapture() (*pcap.Handle, error) {
	const promiscuous = true

	capture, err := pcap.OpenLive(h.config.Interface, DefaultSnapshotLen, promiscuous, pcap.BlockForever)
	if err == nil {
		return capture, nil
	}

	logger.LogError(h.logger, HoneypotType, "open_live_failed", err, []any{
		"interface", h.config.Interface,
	})
	return nil, err
}

// buildBPFFilter assembles the capture filter. The built-in part keeps
// packets that do not originate from ifaceIP, that are TCP packets with the
// SYN flag set, UDP packets or ICMP echo requests, and that do not come from
// source port 53. When Config.BpfExpression is non-empty it is ANDed onto
// the built-in filter.
func (h *packetLoggerHoneypot) buildBPFFilter(ifaceIP net.IP) string {
	const (
		probeClause      = "tcp[tcpflags] & tcp-syn != 0 or udp or icmp[0] == 8"
		noDNSReplyClause = "not src port 53" // exclude DNS responses from dns lookups
	)

	filter := "(not src host " + ifaceIP.String() + ") and (" + probeClause + ") and (" + noDNSReplyClause + ")"

	if extra := h.config.BpfExpression; extra != "" {
		filter = "(" + filter + ") and (" + extra + ")"
	}
	return filter
}

// setupGracefulShutdown spawns a goroutine that waits for ctx to be
// cancelled and then closes the capture handle.
func (h *packetLoggerHoneypot) setupGracefulShutdown(ctx context.Context, handle *pcap.Handle) {
	cancelled := ctx.Done()
	go func() {
		<-cancelled
		handle.Close()
	}()
}

// processPackets reads packets from the capture handle and passes each one
// to handlePacket. It returns when the packet channel is closed or when ctx
// has been cancelled by the time a packet is received.
func (h *packetLoggerHoneypot) processPackets(ctx context.Context, handle *pcap.Handle) {
	source := gopacket.NewPacketSource(handle, handle.LinkType())
	incoming := source.Packets()

	for {
		pkt, open := <-incoming
		if !open || ctx.Err() != nil {
			return
		}
		h.handlePacket(pkt)
	}
}

// handlePacket dispatches one captured packet to the TCP, UDP and ICMP
// handlers according to the layers it contains. Packets without an IPv4
// layer are ignored.
func (h *packetLoggerHoneypot) handlePacket(packet gopacket.Packet) {
	// A comma-ok assertion on a nil Layer yields ok == false, so this covers
	// both "layer absent" and "unexpected concrete type".
	ip, isIPv4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
	if !isIPv4 {
		return
	}

	if tcp, isTCP := packet.Layer(layers.LayerTypeTCP).(*layers.TCP); isTCP {
		h.handleTCPPacket(ip, tcp)
	}

	if udp, isUDP := packet.Layer(layers.LayerTypeUDP).(*layers.UDP); isUDP {
		h.handleUDPPacket(ip, udp)
	}

	if icmp, isICMP := packet.Layer(layers.LayerTypeICMPv4).(*layers.ICMPv4); isICMP {
		h.handleICMPPacket(ip, icmp)
	}
}

// handleTCPPacket logs a TCP event for connection-opening segments, i.e.
// those with SYN set and ACK clear. All other TCP packets are ignored.
func (h *packetLoggerHoneypot) handleTCPPacket(ip *layers.IPv4, tcp *layers.TCP) {
	if !tcp.SYN || tcp.ACK {
		return
	}

	srcPort := utils.SanitizePort(tcp.SrcPort.String(), h.logger)
	targetPort := utils.SanitizePort(tcp.DstPort.String(), h.logger)

	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventTCPPacket,
		RemoteAddr: ip.SrcIP.String(),
		RemotePort: srcPort,
		DstPort:    targetPort,
	}
	logger.LogEvent(h.logger, event)
}

// handleUDPPacket logs a UDP event carrying the sender address and the
// sanitized source and destination ports of the datagram.
func (h *packetLoggerHoneypot) handleUDPPacket(ip *layers.IPv4, udp *layers.UDP) {
	srcPort := utils.SanitizePort(udp.SrcPort.String(), h.logger)
	targetPort := utils.SanitizePort(udp.DstPort.String(), h.logger)

	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventUDPPacket,
		RemoteAddr: ip.SrcIP.String(),
		RemotePort: srcPort,
		DstPort:    targetPort,
	}
	logger.LogEvent(h.logger, event)
}

// handleICMPPacket logs an ICMP event for echo requests (ICMP type 8) and
// ignores every other ICMP message.
func (h *packetLoggerHoneypot) handleICMPPacket(ip *layers.IPv4, icmp *layers.ICMPv4) {
	if icmp.TypeCode.Type() != layers.ICMPv4TypeEchoRequest {
		return
	}

	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventICMPPacket,
		RemoteAddr: ip.SrcIP.String(),
	}
	logger.LogEvent(h.logger, event)
}

// GetScores derives threat scores from the packet logger events stored in
// the honeypot_events table and merges them with honeypot.MergeScores.
//
// Four heuristics are evaluated:
//   - port scan: addresses with at least 10 packets on at least 5 distinct
//     destination ports within a fixed 12 hour window (interval is not used
//     here); score is the distinct port count times 50.
//   - ping scan: bursts where at least countThreshold ICMP echo events fall
//     inside a rolling bucketSize window during the last interval; every
//     address seen in a burst is scored with the burst's packet count times 50.
//   - high traffic: addresses with at least 200 packets during the last
//     interval; score is the packet count.
//   - botnet: /18, /20, /22 and /24 subnets containing at least 5 distinct
//     addresses whose summed event count times summed port count is at
//     least 50; entries are keyed by CIDR and reduced to the most specific
//     subnets by filterSubnets.
//
// interval is spliced verbatim after the SQL INTERVAL keyword.
// NOTE(review): this is string-built SQL; interval must come from trusted
// configuration only — confirm against callers.
//
// Any query, scan or row iteration error yields an empty ScoreMap.
func (h *packetLoggerHoneypot) GetScores(db *database.Database, interval string) honeypot.ScoreMap {
	// Port scans: at least 10 packets on at least 5 different ports in 12 hours.
	query := `
	SELECT remote_addr, COUNT(DISTINCT dst_port) as port_count, COUNT(*) AS packet_count
	FROM honeypot_events
	WHERE type = 'packetlogger' AND time >= now() - INTERVAL '12 HOURS'
	GROUP BY remote_addr HAVING COUNT(DISTINCT dst_port) >= 5 AND COUNT(*) >= 10
	`

	rows, err := db.DB.Query(query)
	if err != nil {
		return honeypot.ScoreMap{}
	}
	defer rows.Close()

	portScanScores := honeypot.ScoreMap{}
	for rows.Next() {
		var ip string
		var portCount uint
		var packetCount uint
		if err := rows.Scan(&ip, &portCount, &packetCount); err != nil {
			return honeypot.ScoreMap{}
		}
		portScanScores[ip] = honeypot.Score{Score: portCount * 50, Tags: []types.Tag{types.TagPortScan}}
	}
	// rows.Next also returns false on failure; do not report partial results.
	if err := rows.Err(); err != nil {
		return honeypot.ScoreMap{}
	}

	// Ping scans: 9 or more echo requests inside a rolling one minute window.
	const countThreshold = 9
	const bucketSize = "60 seconds"

	// Placeholder order: rolling window width (bucketSize), look-back period
	// (interval), burst threshold, lead-in added before each burst (bucketSize).
	query = fmt.Sprintf(`
	WITH raw_counts AS (
        SELECT 
            time,
            remote_addr,
            COUNT(*) OVER (
                ORDER BY time 
                RANGE BETWEEN INTERVAL %s PRECEDING AND CURRENT ROW
            ) as rolling_count
        FROM honeypot_events
        WHERE event = 'icmp_packet' AND time >= now() - INTERVAL %s
    ),
    peaks AS (
        SELECT 
            time,
            remote_addr,
            rolling_count,
            CASE 
                WHEN time - LAG(time) OVER (ORDER BY time) <= INTERVAL '10 seconds' 
                THEN 0 
                ELSE 1 
            END AS is_new_event
        FROM raw_counts
        WHERE rolling_count >= %d
    ),
    peak_groups AS (
        SELECT 
            time,
            remote_addr,
            SUM(is_new_event) OVER (ORDER BY time) AS group_id
        FROM peaks
    ),
    scan_windows AS (
        SELECT
            group_id,
            MIN(time) AS scan_start,
            MAX(time) AS scan_end,
            COUNT(*) AS peak_packets_count,
            list(DISTINCT remote_addr) AS unique_ips
        FROM peak_groups
        GROUP BY group_id
    )
    SELECT 
        scan_start,
        scan_end,
        (
            SELECT COUNT(*) 
            FROM honeypot_events 
            WHERE event = 'icmp_packet' 
              AND time >= s.scan_start - INTERVAL %s
              AND time <= s.scan_end
        ) AS total_involved_packets,
        unique_ips
    FROM scan_windows s
    ORDER BY scan_start DESC
	`, bucketSize, interval, countThreshold, bucketSize)

	rows, err = db.DB.Query(query)
	if err != nil {
		return honeypot.ScoreMap{}
	}
	defer rows.Close()

	pingScores := []honeypot.ScoreMap{}
	for rows.Next() {
		var scanStart time.Time
		var scanEnd time.Time
		var totalInvolvedPackets uint
		var uniqueIps []any
		if err := rows.Scan(&scanStart, &scanEnd, &totalInvolvedPackets, &uniqueIps); err != nil {
			return honeypot.ScoreMap{}
		}

		for _, entry := range uniqueIps {
			addr, ok := entry.(string)
			if !ok {
				// Skip list elements that are not strings (e.g. NULL)
				// instead of panicking on the type assertion.
				continue
			}
			pingScores = append(pingScores, honeypot.ScoreMap{
				addr: honeypot.Score{Score: totalInvolvedPackets * 50, Tags: []types.Tag{types.TagPingScan}},
			})
		}
	}
	if err := rows.Err(); err != nil {
		return honeypot.ScoreMap{}
	}

	pingScoresMap := honeypot.MergeScores(pingScores...)

	// High traffic: addresses with at least 200 packets in the interval.
	query = fmt.Sprintf(`
	SELECT remote_addr, COUNT(*) as packet_count
	FROM honeypot_events
	WHERE type = 'packetlogger' AND time >= now() - INTERVAL %s
	GROUP BY remote_addr HAVING COUNT(*) >= 200
	ORDER BY packet_count DESC
	`, interval)
	rows, err = db.DB.Query(query)
	if err != nil {
		return honeypot.ScoreMap{}
	}
	defer rows.Close()

	highTrafficScores := honeypot.ScoreMap{}
	for rows.Next() {
		var ip string
		var packetCount uint
		if err := rows.Scan(&ip, &packetCount); err != nil {
			return honeypot.ScoreMap{}
		}
		highTrafficScores[ip] = honeypot.Score{Score: packetCount, Tags: []types.Tag{types.TagHighTraffic}}
	}
	if err := rows.Err(); err != nil {
		return honeypot.ScoreMap{}
	}

	// Botnets by subnet (/18, /20, /22, /24): at least 5 distinct addresses
	// in the subnet and summed events times summed ports of at least 50.
	query = fmt.Sprintf(`
WITH filtered_ips AS (
    SELECT
        remote_ip_int,
        COUNT(*) AS ip_count,
		COUNT(DISTINCT dst_port) AS port_count
    FROM honeypot_events
    WHERE type = 'packetlogger'
      AND time >= now() - INTERVAL %s
    GROUP BY remote_ip_int
),
subnet_stats AS (
    SELECT
        mask,
        (remote_ip_int >> (32 - mask)) AS subnet,
        COUNT(*) AS addr_count,
        SUM(ip_count) AS event_count,
        SUM(port_count) AS port_count
    FROM filtered_ips
    CROSS JOIN (VALUES (18), (20), (22), (24)) AS masks(mask)
    GROUP BY mask, subnet
    HAVING COUNT(*) >= 5 AND SUM(ip_count) * SUM(port_count) >= 50
)
SELECT
    f.remote_ip_int,
    s.mask,
    s.event_count,
	s.port_count
FROM filtered_ips f
JOIN subnet_stats s
  ON (f.remote_ip_int >> (32 - s.mask)) = s.subnet
`, interval)

	rows, err = db.DB.Query(query)
	if err != nil {
		return honeypot.ScoreMap{}
	}
	defer rows.Close()

	// Every member address of a subnet carries the same subnet-wide totals,
	// so repeated writes to the same CIDR key store identical values.
	subnetScores := honeypot.ScoreMap{}
	for rows.Next() {
		var ipInt uint32
		var mask int
		var packetCount uint
		var portCount uint

		if err := rows.Scan(&ipInt, &mask, &packetCount, &portCount); err != nil {
			return honeypot.ScoreMap{}
		}

		subnet, err := maskIP(utils.IntToIP(ipInt), mask)
		if err != nil {
			continue
		}

		subnetScores[subnet] = honeypot.Score{
			Score: packetCount * portCount,
			Tags:  []types.Tag{types.TagBotnet},
		}
	}
	if err := rows.Err(); err != nil {
		return honeypot.ScoreMap{}
	}

	botnetScores := filterSubnets(subnetScores)

	return honeypot.MergeScores(
		portScanScores,
		pingScoresMap,
		highTrafficScores,
		botnetScores,
	)
}

// filterSubnets drops every subnet that contains another subnet of the map,
// so that only the most specific (smallest) subnets and their scores remain.
func filterSubnets(scores honeypot.ScoreMap) honeypot.ScoreMap {
	kept := honeypot.ScoreMap{}

candidates:
	for candidate, candidateScore := range scores {
		// A candidate that covers an already kept subnet is too general.
		for specific := range kept {
			if NetIncludesNet(candidate, specific) {
				continue candidates
			}
		}

		// The candidate is more specific than anything kept that covers it;
		// evict those broader entries before storing it.
		for broader := range kept {
			if NetIncludesNet(broader, candidate) {
				delete(kept, broader)
			}
		}
		kept[candidate] = candidateScore
	}

	return kept
}

// maskIP returns the network that contains ipStr for the given prefix
// length, formatted in CIDR notation.
// Example: maskIP("192.168.156.123", 16) -> "192.168.0.0/16"
//
// An error is returned when ipStr is not a valid IP address, when it is not
// an IPv4 address, or when maskBits is outside the range 0 to 32.
func maskIP(ipStr string, maskBits int) (string, error) {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return "", fmt.Errorf("invalid IP address: %s", ipStr)
	}

	ip = ip.To4()
	if ip == nil {
		return "", fmt.Errorf("only IPv4 is supported")
	}

	// net.CIDRMask returns nil for an out-of-range prefix length; masking
	// with it would yield a nil IP that formats as "<nil>/N".
	mask := net.CIDRMask(maskBits, 32)
	if mask == nil {
		return "", fmt.Errorf("invalid mask size: %d", maskBits)
	}
	network := ip.Mask(mask)

	return fmt.Sprintf("%s/%d", network.String(), maskBits), nil
}

// NetIncludesNet reports whether the network net1 covers the network net2:
// net1's prefix must be no longer than net2's and net1 must contain net2's
// base address. Both arguments are CIDR strings; an unparsable argument
// yields false.
// Example: NetIncludesNet("192.168.0.0/16", "192.168.156.123/32") -> true
func NetIncludesNet(net1 string, net2 string) bool {
	_, outer, outerErr := net.ParseCIDR(net1)
	_, inner, innerErr := net.ParseCIDR(net2)
	if outerErr != nil || innerErr != nil {
		return false
	}

	outerBits, _ := outer.Mask.Size()
	innerBits, _ := inner.Mask.Size()
	if outerBits > innerBits {
		// A longer prefix can never cover a shorter one.
		return false
	}

	return outer.Contains(inner.IP)
}

// Ports returns the ports used by this honeypot. The packet logger always
// returns nil.
func (h *packetLoggerHoneypot) Ports() []uint16 {
	var none []uint16
	return none
}