internal/honeypot/dns/dns.go

package dns

import (
	"context"
	"log/slog"
	"sync"
	"time"

	"honeypot/internal/database"
	"honeypot/internal/honeypot"
	"honeypot/internal/logger"
	"honeypot/internal/types"
	"honeypot/internal/utils"

	"github.com/miekg/dns"
)

const (
	// HoneypotType identifies this honeypot; it is returned by Name and
	// attached to every log event this package emits.
	HoneypotType = types.HoneypotTypeDNS

	// HoneypotLabel is the human-readable name returned by Label.
	HoneypotLabel = "DNS"
)

// ShutdownTimeout bounds how long each listener's graceful shutdown may take
// once the context passed to Start is cancelled.
const ShutdownTimeout = 3 * time.Second

// Config holds the settings a DNS honeypot is created with via New.
type Config struct {
	// ListenAddr is the host part of every listen address; it is combined
	// with each port by utils.BuildAddress.
	ListenAddr string
	// Ports lists the ports to serve. Each non-zero entry gets both a UDP
	// and a TCP listener; zero entries are skipped by Start.
	Ports      []uint16
}

// dnsHoneypot is the honeypot.Honeypot implementation returned by New. It
// serves DNS over UDP and TCP and logs every question it receives.
type dnsHoneypot struct {
	// config is the configuration passed to New.
	config Config
	// logger is assigned by Start and used by the listeners and the query
	// handler; it is nil until Start runs.
	logger *slog.Logger
}

// New builds a DNS honeypot from cfg. It only records the configuration;
// no sockets are opened until Start is called.
func New(cfg Config) honeypot.Honeypot {
	h := &dnsHoneypot{config: cfg}
	return h
}

// Name reports the honeypot type identifier, HoneypotType.
func (h *dnsHoneypot) Name() types.HoneypotType {
	return HoneypotType
}

// Label reports the human-readable honeypot name, HoneypotLabel.
func (h *dnsHoneypot) Label() string {
	return HoneypotLabel
}

// Start stores l as the honeypot's logger and runs a UDP and a TCP DNS
// listener for every non-zero port in the configuration. It blocks until
// every listener's runServer call has returned (normally after ctx is
// cancelled) and then logs that shutdown is complete.
//
// Start always returns nil: listener errors are logged by runServer rather
// than propagated to the caller.
func (h *dnsHoneypot) Start(ctx context.Context, l *slog.Logger) error {
	h.logger = l

	var listeners sync.WaitGroup

	// launch runs one listener in its own goroutine, tracked by listeners.
	launch := func(p uint16, network string) {
		listeners.Add(1)
		go func() {
			defer listeners.Done()
			h.runServer(ctx, p, network)
		}()
	}

	for _, port := range h.config.Ports {
		if port == 0 {
			continue // port 0 entries are skipped
		}
		launch(port, "udp")
		launch(port, "tcp")
	}

	listeners.Wait()
	logger.LogInfo(h.logger, HoneypotType, "honeypot shutdown complete", nil)
	return nil
}

// runServer binds a DNS server on port using netType ("udp" or "tcp") and
// answers with h.handleQuery until ctx is cancelled or the server stops on
// its own (for example because the address could not be bound).
//
// It returns only after ListenAndServe has returned and, when ctx triggered
// a shutdown, after that shutdown has finished. This lets Start's WaitGroup
// really wait for every listener to be torn down.
func (h *dnsHoneypot) runServer(ctx context.Context, port uint16, netType string) {
	addr := utils.BuildAddress(h.config.ListenAddr, port)

	// started is closed by the dns package once the socket is bound and the
	// server is serving. ShutdownContext refuses to stop a server that has
	// not started yet, so cancelling before this point must wait for it;
	// otherwise the shutdown is lost and the server runs forever.
	started := make(chan struct{})
	onStarted := func() {
		// Logged here rather than before ListenAndServe so a failed bind is
		// never reported as "listening".
		logger.LogInfo(h.logger, HoneypotType, "honeypot listening", []any{
			"port", port,
			"net", netType,
		})
		close(started)
	}

	server := &dns.Server{
		Addr:              addr,
		Net:               netType,
		Handler:           dns.HandlerFunc(h.handleQuery),
		NotifyStartedFunc: onStarted,
	}

	serveDone := make(chan struct{})    // closed when ListenAndServe returns
	shutdownDone := make(chan struct{}) // closed when the shutdown goroutine exits

	go func() {
		defer close(shutdownDone)

		// Wait for cancellation, but give up if the server already exited
		// (e.g. bind failure), since there is then nothing to stop.
		select {
		case <-ctx.Done():
		case <-serveDone:
			return
		}

		// Cancellation may arrive before the socket is bound; wait until the
		// server is actually running, or has failed to start.
		select {
		case <-started:
		case <-serveDone:
			return
		}

		shutdownCtx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
		defer cancel()
		if err := server.ShutdownContext(shutdownCtx); err != nil {
			logger.LogError(h.logger, HoneypotType, "server_shutdown_failed", err, []any{
				"net", netType,
				"port", port,
			})
		}
	}()

	err := server.ListenAndServe()
	close(serveDone)

	// Errors after cancellation are expected side effects of shutdown.
	if err != nil && ctx.Err() == nil {
		logger.LogError(h.logger, HoneypotType, "server_failed", err, []any{
			"net", netType,
			"port", port,
		})
	}

	// Do not report this listener as finished until any in-progress
	// shutdown has completed.
	<-shutdownDone
}

// handleQuery is the dns.Handler shared by every listener. For each question
// in r it emits one types.EventDNSQuery log event. The event carries:
//   - the client's address and port;
//   - the local port that received the message;
//   - the question and header details, in Fields.
//
// Query types and classes the dns package has no mnemonic for are logged in
// the generic "TYPEnnn"/"CLASSnnn" form rather than as an empty string, so
// unusual probes stay identifiable in the logs.
//
// NOTE(review): no reply is ever written, so clients receive no answer and
// eventually time out; confirm that silence is the intended behavior.
func (h *dnsHoneypot) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
	remoteHost, remotePort := utils.SplitAddr(w.RemoteAddr().String(), h.logger)
	_, dstPort := utils.SplitAddr(w.LocalAddr().String(), h.logger)

	// These depend only on the message and connection, not on the question.
	network := w.RemoteAddr().Network()
	opcode := dns.OpcodeToString[r.Opcode]

	for _, q := range r.Question {
		// Keep the table's mnemonics for known classes (e.g. "ANY"); only
		// fall back to the numeric "CLASSnnn" form for unknown ones.
		qclass, ok := dns.ClassToString[q.Qclass]
		if !ok {
			qclass = dns.Class(q.Qclass).String()
		}

		fields := map[string]any{
			"qname":  q.Name,
			"qtype":  dns.Type(q.Qtype).String(), // "TYPEnnn" when unknown
			"qclass": qclass,
			"id":     r.Id,
			"opcode": opcode,
			"rd":     r.RecursionDesired,
			"net":    network,
		}

		logger.LogEvent(h.logger, types.LogEvent{
			Type:       HoneypotType,
			Event:      types.EventDNSQuery,
			RemoteAddr: remoteHost,
			RemotePort: remotePort,
			DstPort:    dstPort,
			Fields:     fields,
		})
	}
}

// GetScores returns an empty ScoreMap. This honeypot logs DNS queries but
// does not compute scores, so db and interval are ignored.
func (h *dnsHoneypot) GetScores(db *database.Database, interval string) honeypot.ScoreMap {
	empty := honeypot.ScoreMap{}
	return empty
}

// Ports returns the configured port list. The slice is the one held in the
// honeypot's config, not a copy, so callers should treat it as read-only.
func (h *dnsHoneypot) Ports() []uint16 {
	configured := h.config.Ports
	return configured
}