internal/metrics/metrics.go

package metrics

import (
	"honeypot/internal/database"
	"honeypot/internal/types"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// topNCount is how many of the most frequent values per (honeypot type, field)
// pair are exported through the honeypot_top_fields gauge.
const topNCount = 20

// SanitizeUTF8 makes s safe to use as a Prometheus label value.
//
// A string that is already valid UTF-8 is returned as-is. Otherwise each run of
// invalid bytes is replaced by a single U+FFFD replacement character, as
// strings.ToValidUTF8 does.
func SanitizeUTF8(s string) string {
	if !utf8.ValidString(s) {
		s = strings.ToValidUTF8(s, "\uFFFD")
	}
	return s
}

// MetricsCollector turns honeypot events into Prometheus metrics and, on each
// scrape, refreshes derived gauges (top field values, host load/temperature and
// database statistics).
//
// mu guards fieldCounts, registeredFields and db. The Prometheus vectors are
// themselves safe for concurrent use and are also updated outside mu (for
// example by the hardware and database refreshers).
type MetricsCollector struct {
	// disableHWMetrics skips the load-average and CPU-temperature refresh on scrape.
	disableHWMetrics bool

	packetsTotal *prometheus.CounterVec // honeypot_packets_total{type,event}
	authAttempts *prometheus.CounterVec // honeypot_auth_attempts_total{type}
	topFields    *prometheus.GaugeVec   // honeypot_top_fields{type,field_name,value}
	loadAvg      *prometheus.GaugeVec   // honeypot_load_avg{type}
	tempCPU      *prometheus.GaugeVec   // honeypot_temp_cpu{type}
	databaseRows *prometheus.GaugeVec   // honeypot_database_rows{table}
	databaseSize prometheus.Gauge       // honeypot_database_size_bytes

	// db is the stats source attached via SetDatabase; while nil, the database
	// gauges are not refreshed.
	db *database.Database

	// fieldCounts[honeypotType][fieldName][value] counts events seen with that
	// value; it feeds honeypot_top_fields.
	fieldCounts map[string]map[string]map[string]int64

	// registeredFields[honeypotType] is the set of field names counted for that
	// type; entries under the key "*" apply to every type.
	registeredFields map[string]map[string]bool

	mu sync.RWMutex
}

// NewMetricsCollector builds a MetricsCollector, registers its metrics with the
// default Prometheus registry and seeds the common top-N field registrations.
//
// When disableHWMetrics is true, scrapes skip refreshing the load-average and
// CPU-temperature gauges.
//
// Registration errors are ignored on purpose, for example when a collector has
// already been created in this process. Be aware of the consequence: a metric
// whose name is already registered is not re-registered, so this collector's
// instance of it is not the one the default registry exports.
func NewMetricsCollector(disableHWMetrics bool) *MetricsCollector {
	mc := &MetricsCollector{
		disableHWMetrics: disableHWMetrics,
		fieldCounts:      make(map[string]map[string]map[string]int64),
		registeredFields: make(map[string]map[string]bool),
	}

	mc.packetsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "honeypot_packets_total",
		Help: "Total number of packets/requests by honeypot type and event type",
	}, []string{"type", "event"})
	mc.authAttempts = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "honeypot_auth_attempts_total",
		Help: "Total number of authentication attempts by honeypot type",
	}, []string{"type"})
	mc.topFields = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "honeypot_top_fields",
		Help: "Top field values by count, configurable per honeypot type",
	}, []string{"type", "field_name", "value"})
	mc.loadAvg = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "honeypot_load_avg",
		Help: "Load average of the system",
	}, []string{"type"})
	mc.tempCPU = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "honeypot_temp_cpu",
		Help: "Temperature in degrees Celsius",
	}, []string{"type"})
	mc.databaseRows = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "honeypot_database_rows",
		Help: "Number of rows in database tables",
	}, []string{"table"})
	mc.databaseSize = prometheus.NewGauge(prometheus.GaugeOpts{
		Name: "honeypot_database_size_bytes",
		Help: "Total size of the database file and WAL in bytes",
	})

	// Register in a fixed order; errors are dropped deliberately (see doc comment).
	for _, c := range []prometheus.Collector{
		mc.packetsTotal,
		mc.authAttempts,
		mc.topFields,
		mc.loadAvg,
		mc.tempCPU,
		mc.databaseRows,
		mc.databaseSize,
	} {
		_ = prometheus.Register(c)
	}

	mc.registerCommonFields()
	return mc
}

// registerCommonFields seeds the top-N registrations every collector starts with:
// username, password and remote_addr for all honeypot types (the "*" key), plus
// the "port" field for the packetlogger type. The port value itself is built in
// RecordEvent from the packet protocol and destination port.
func (mc *MetricsCollector) registerCommonFields() {
	defaults := []struct {
		honeypotType types.HoneypotType
		field        string
	}{
		{"*", "username"},
		{"*", "password"},
		{"*", "remote_addr"},
		{"packetlogger", "port"},
	}
	for _, d := range defaults {
		mc.RegisterTopNField(d.honeypotType, d.field)
	}
}

// RecordEvent processes a honeypot event and updates metrics.
//
// Every event increments honeypot_packets_total{type,event}; auth-attempt events
// also increment honeypot_auth_attempts_total{type}. In addition, for each field
// registered via RegisterTopNField for e.Type or for the wildcard type "*", the
// event's value of that field is counted in fieldCounts, but only if it is a
// non-empty string. fieldCounts feeds honeypot_top_fields on the next scrape.
//
// NOTE(review): nothing in this file removes entries from fieldCounts. Every
// distinct value seen (attacker-chosen usernames, passwords, addresses, ...) is
// therefore kept for the life of the process. Consider bounding it if memory
// matters.
func (mc *MetricsCollector) RecordEvent(e types.LogEvent) {
	mc.mu.Lock()
	defer mc.mu.Unlock()

	t := string(e.Type)

	mc.packetsTotal.WithLabelValues(t, string(e.Event)).Inc()
	if e.Event == types.EventAuthAttempt {
		mc.authAttempts.WithLabelValues(t).Inc()
	}

	// This runs for every event under the exclusive lock, so it should not
	// allocate. Walk the type-specific and wildcard registrations directly
	// instead of merging them into a fresh set first. A field registered under
	// both keys is still counted only once, as with a set union.
	typeFields := mc.registeredFields[t]
	for fieldName := range typeFields {
		mc.countTrackedField(t, fieldName, &e)
	}
	for fieldName := range mc.registeredFields["*"] {
		if _, dup := typeFields[fieldName]; dup {
			continue
		}
		mc.countTrackedField(t, fieldName, &e)
	}
}

// countTrackedField increments fieldCounts[hType][fieldName][value] when e has a
// non-empty string value for fieldName. Values of any other type are ignored.
// The caller must hold mc.mu for writing.
func (mc *MetricsCollector) countTrackedField(hType, fieldName string, e *types.LogEvent) {
	raw, ok := trackedFieldValue(e, fieldName)
	if !ok {
		return
	}
	value, isString := raw.(string)
	if !isString || value == "" {
		return
	}

	byField := mc.fieldCounts[hType]
	if byField == nil {
		byField = make(map[string]map[string]int64)
		mc.fieldCounts[hType] = byField
	}
	byValue := byField[fieldName]
	if byValue == nil {
		byValue = make(map[string]int64)
		byField[fieldName] = byValue
	}
	byValue[value]++
}

// trackedFieldValue resolves fieldName against e.
//
// Some tracked fields are not stored in e.Fields:
//   - "remote_addr" comes from e.RemoteAddr.
//   - "port" on packetlogger TCP/UDP events is built as "tcp:<DstPort>" or
//     "udp:<DstPort>".
//
// Every other field is looked up in e.Fields. The boolean reports whether a
// value was found.
func trackedFieldValue(e *types.LogEvent, fieldName string) (any, bool) {
	switch {
	case fieldName == "remote_addr":
		return e.RemoteAddr, e.RemoteAddr != ""
	case fieldName == "port" && e.Type == types.HoneypotTypePacketLogger && e.Event != types.EventICMPPacket:
		var protocol string
		switch e.Event {
		case types.EventTCPPacket:
			protocol = "tcp"
		case types.EventUDPPacket:
			protocol = "udp"
		}
		if protocol == "" {
			return nil, false
		}
		return protocol + ":" + strconv.Itoa(int(e.DstPort)), true
	case e.Fields != nil:
		v, ok := e.Fields[fieldName]
		return v, ok
	}
	return nil, false
}

// RegisterTopNField adds fieldName to the set of event fields whose values are
// counted for honeypotType and exported through honeypot_top_fields.
//
// Registering under the type "*" tracks the field for every honeypot type.
// Registering the same (type, field) pair again has no further effect.
func (mc *MetricsCollector) RegisterTopNField(honeypotType types.HoneypotType, fieldName string) {
	key := string(honeypotType)

	mc.mu.Lock()
	defer mc.mu.Unlock()

	fields := mc.registeredFields[key]
	if fields == nil {
		fields = make(map[string]bool)
		mc.registeredFields[key] = fields
	}
	fields[fieldName] = true
}

// updateTopNGauges refreshes every top-N gauge before a scrape. At present that
// is only the honeypot_top_fields vector rebuilt by updateTopFields. Callers must
// hold mc.mu; GetHandler takes the read lock around this call.
func (mc *MetricsCollector) updateTopNGauges() {
	mc.updateTopFields()
}

// updateTopFields rebuilds the honeypot_top_fields gauge from fieldCounts.
//
// The vector is reset first, so values that fell out of the top N stop being
// exported. For every (honeypot type, field) pair, the topNCount most frequent
// values are then published with their counts.
//
// Label values go through SanitizeUTF8, because the client library rejects label
// values that are not valid UTF-8. Sanitizing can map distinct raw values onto
// the same label, so counts are accumulated with Add rather than Set: after the
// reset each series starts at zero, and a collision sums the counts instead of
// letting the last value silently overwrite the others.
//
// Reads mc.fieldCounts without locking. The caller in this file (GetHandler, via
// updateTopNGauges) holds mc.mu.RLock.
func (mc *MetricsCollector) updateTopFields() {
	mc.topFields.Reset()

	for hType, fieldNames := range mc.fieldCounts {
		for fieldName, values := range fieldNames {
			for value, count := range getTopN(values, topNCount) {
				mc.topFields.WithLabelValues(
					SanitizeUTF8(hType),
					SanitizeUTF8(fieldName),
					SanitizeUTF8(value),
				).Add(float64(count))
			}
		}
	}
}

// updateLoadAvg reads /proc/loadavg and publishes its first three values as
// honeypot_load_avg{type="loadavg1"}, {type="loadavg5"} and {type="loadavg15"}.
//
// Best-effort: if the file cannot be read or its contents cannot be parsed, the
// gauges keep their previous values. All three values are parsed before any
// gauge is set, so a partial parse never publishes a mixed update.
func (mc *MetricsCollector) updateLoadAvg() {
	data, err := os.ReadFile("/proc/loadavg")
	if err != nil {
		return
	}

	// strings.Fields tolerates repeated whitespace and the trailing newline. The
	// length check prevents an index-out-of-range panic on short or empty content.
	// Such a panic would otherwise fire while GetHandler holds mc.mu.RLock without
	// a deferred unlock, leaving the lock held forever.
	fields := strings.Fields(string(data))
	if len(fields) < 3 {
		return
	}

	var avgs [3]float64
	for i := range avgs {
		v, err := strconv.ParseFloat(fields[i], 64)
		if err != nil {
			return
		}
		avgs[i] = v
	}

	mc.loadAvg.WithLabelValues("loadavg1").Set(avgs[0])
	mc.loadAvg.WithLabelValues("loadavg5").Set(avgs[1])
	mc.loadAvg.WithLabelValues("loadavg15").Set(avgs[2])
}

// updateTempCPU publishes the CPU temperature as honeypot_temp_cpu{type="cpu"}.
//
// The value is read from /sys/class/thermal/thermal_zone0/temp. The Linux
// thermal sysfs interface reports it in millidegrees Celsius, so it is divided
// by 1000 to match the metric's "degrees Celsius" help text. Best-effort: if the
// file cannot be read or parsed, the gauge keeps its previous value.
//
// NOTE(review): which sensor thermal_zone0 corresponds to is platform-specific;
// confirm it is the CPU on target hosts.
func (mc *MetricsCollector) updateTempCPU() {
	raw, err := os.ReadFile("/sys/class/thermal/thermal_zone0/temp")
	if err != nil {
		return
	}

	// Parse at float64 precision: the gauge is float64, and bitSize 32 would
	// first round the reading to float32 (inconsistent with updateLoadAvg).
	milliCelsius, err := strconv.ParseFloat(strings.TrimSpace(string(raw)), 64)
	if err != nil {
		return
	}

	mc.tempCPU.WithLabelValues("cpu").Set(milliCelsius / 1000)
}

// SetDatabase attaches db as the source for the honeypot_database_* gauges,
// which are refreshed on every scrape. Passing nil stops those refreshes. The
// pointer is swapped under mc.mu, so this may be called while scrapes are
// running.
func (mc *MetricsCollector) SetDatabase(db *database.Database) {
	mc.mu.Lock()
	mc.db = db
	mc.mu.Unlock()
}

// updateDatabaseStats refreshes honeypot_database_rows{table} and
// honeypot_database_size_bytes from db.GetDatabaseStats.
//
// The database pointer is snapshotted under the read lock, so the stats query
// itself runs without holding mc.mu. If no database is attached or the query
// fails, nothing is updated and the gauges keep their previous values.
func (mc *MetricsCollector) updateDatabaseStats() {
	mc.mu.RLock()
	db := mc.db
	mc.mu.RUnlock()
	if db == nil {
		return
	}

	stats, err := db.GetDatabaseStats()
	if err != nil {
		return
	}

	rows := []struct {
		table string
		count float64
	}{
		{"events", float64(stats.RowsEvents)},
		{"ips", float64(stats.RowsIps)},
		{"blocklist", float64(stats.RowsBlocklist)},
		{"unresolved_ips", float64(stats.RowsUnresolvedIPs)},
	}
	for _, row := range rows {
		mc.databaseRows.WithLabelValues(row.table).Set(row.count)
	}
	mc.databaseSize.Set(float64(stats.DatabaseSize))
}

// getTopN returns the n entries of counts with the highest counts.
//
// Entries are ranked by count, descending. Ties are broken by key, ascending, so
// the selected set is deterministic. Map iteration order is randomized and
// sort.Slice is not stable, so without a tie-breaker, values with equal counts
// at the cut-off could swap in and out of the result between calls (and hence
// between scrapes).
//
// An empty or nil counts, or a non-positive n, yields an empty non-nil map.
func getTopN(counts map[string]int64, n int) map[string]int64 {
	if len(counts) == 0 || n <= 0 {
		return make(map[string]int64)
	}

	type item struct {
		key   string
		count int64
	}
	items := make([]item, 0, len(counts))
	for k, v := range counts {
		items = append(items, item{k, v})
	}

	sort.Slice(items, func(i, j int) bool {
		if items[i].count != items[j].count {
			return items[i].count > items[j].count
		}
		return items[i].key < items[j].key
	})

	if len(items) > n {
		items = items[:n]
	}

	topN := make(map[string]int64, len(items))
	for _, it := range items {
		topN[it.key] = it.count
	}
	return topN
}

// GetHandler returns the HTTP handler for Prometheus metrics from the default
// registry.
//
// Before delegating to promhttp, each request refreshes the derived gauges:
//   - the top-N field values, built from fieldCounts under mc.mu.RLock;
//   - load average and CPU temperature, unless disableHWMetrics is set;
//   - database statistics.
//
// Requests through the returned handler are serialized. updateTopFields resets
// the shared honeypot_top_fields vector and then repopulates it, and the gather
// in h.ServeHTTP happens after that. If two scrapes overlapped (e.g. an HA pair
// of Prometheus servers), one request's Reset could land while the other is
// gathering, and that response would be missing some or all top-field series.
// The lock is held while the response is written, so a slow client delays other
// scrapes; bound that with the server's WriteTimeout.
func (mc *MetricsCollector) GetHandler() http.Handler {
	h := promhttp.Handler()
	var scrapeMu sync.Mutex

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		scrapeMu.Lock()
		defer scrapeMu.Unlock()

		mc.mu.RLock()
		mc.updateTopNGauges()
		mc.mu.RUnlock()

		// The hardware refreshers read files and set their own (concurrency-safe)
		// gauges; they touch no state guarded by mc.mu. Running them outside the
		// lock keeps filesystem reads from stalling RecordEvent.
		if !mc.disableHWMetrics {
			mc.updateLoadAvg()
			mc.updateTempCPU()
		}

		mc.updateDatabaseStats()

		h.ServeHTTP(w, r)
	})
}