// internal/database/stats.go

package database

import (
	"database/sql"
	"encoding/binary"
	"errors"
	"fmt"
	"honeypot/internal/utils"
	"net"
	"net/url"
	"strings"
	"time"
)

// LabelCount pairs a label with the number of rows counted for it. It is the
// element type of the grouped results returned by GetTopNFields and
// GetSubnetStats.
type LabelCount struct {
	// Label is the grouped value; serialized as "label".
	Label string `json:"label"`
	// Count is the COUNT(*) for that value; serialized as "count".
	Count int    `json:"count"`
}

// ActivityCount is one row of the result of GetActivityOverTime: the event
// rate of a single honeypot type within a single time bucket.
type ActivityCount struct {
	// Time is the bucket value produced by time_bucket in the query.
	Time  time.Time `json:"time"`
	// Type is the honeypot type the events belong to.
	Type  string    `json:"type"`
	// Count is the number of events in the bucket divided by the bucket
	// width in minutes, which is why it is a float and not an integer.
	Count float64   `json:"count"`
}

// GetTopNFields returns the most frequent non-NULL values of a field
// expression, most frequent first.
//
// fieldSelect is the SQL expression to group by and whereClause is a filter
// expression. Both are interpolated verbatim into the statement, so they must
// come from trusted code and never from user input. args supplies the values
// for the ? placeholders in whereClause and is not modified. limit caps the
// number of rows returned.
//
// It returns an error when the database is not connected, or when the query,
// a row scan or the row iteration fails.
func (db *Database) GetTopNFields(fieldSelect string, whereClause string, args []any, limit int) ([]LabelCount, error) {
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	// Build the parameter list in a fresh slice: appending to args directly
	// could write the limit into spare capacity of the caller's backing array.
	params := make([]any, 0, len(args)+1)
	params = append(params, args...)
	params = append(params, limit)

	rows, err := db.DB.Query(fmt.Sprintf(`
		SELECT %s as field, COUNT(*) as count
		FROM honeypot_events
		WHERE %s AND field IS NOT NULL
		GROUP BY field
		ORDER BY count DESC
		LIMIT ?
		`, fieldSelect, whereClause), params...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var counts []LabelCount
	for rows.Next() {
		var fc LabelCount
		if err := rows.Scan(&fc.Label, &fc.Count); err != nil {
			return nil, err
		}
		counts = append(counts, fc)
	}
	// Next returns false both when the rows are exhausted and when iteration
	// fails; only Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return counts, nil
}

// HoneypotStats contains stats for a honeypot type: overall totals plus lists
// of the most frequent values (remote addrs, ports, event types and
// protocol-specific fields). GetHoneypotStats fills the totals, the three
// top_* lists and whichever protocol lists apply to the requested type.
//
// top_addrs, top_ports and top_events carry no omitempty, so they are always
// present in the JSON output; every other field is omitted when empty.
type HoneypotStats struct {
	// Title and Metadata are not set by GetHoneypotStats.
	Title          string         `json:"title,omitempty"`
	Metadata       map[string]any `json:"metadata,omitempty"`
	// TotalEvents is the number of events of the requested type.
	TotalEvents    int            `json:"total_events,omitempty"`
	// FirstSeen and LastSeen are RFC 3339 (nanosecond) timestamps, left empty
	// when the aggregate query returned NULL.
	FirstSeen      string         `json:"first_seen,omitempty"`
	LastSeen       string         `json:"last_seen,omitempty"`
	// Lists collected for every honeypot type.
	RemoteAddrs    []LabelCount   `json:"top_addrs"`
	Ports          []LabelCount   `json:"top_ports"`
	EventTypes     []LabelCount   `json:"top_events"`
	// Protocol-specific lists; which ones are filled depends on the type.
	Methods        []LabelCount   `json:"methods,omitempty"`
	UserAgents     []LabelCount   `json:"user_agents,omitempty"`
	URIs           []LabelCount   `json:"uris,omitempty"`
	Hosts          []LabelCount   `json:"hosts,omitempty"`
	Referers       []LabelCount   `json:"referers,omitempty"`
	ClientVersions []LabelCount   `json:"client_versions,omitempty"`
	Usernames      []LabelCount   `json:"usernames,omitempty"`
	Passwords      []LabelCount   `json:"passwords,omitempty"`
	SecurityLayers []LabelCount   `json:"security_layers,omitempty"`
	// FQDNs is not set by GetHoneypotStats.
	FQDNs          []LabelCount   `json:"fqdns,omitempty"`
}

// Per-protocol field tables used by GetHoneypotStats. Each entry is a pair:
// the SQL expression selecting the field, followed by the key that decides
// which HoneypotStats list receives the resulting counts.
var (
	// httpFields lists the fields aggregated for "http" events.
	httpFields = [][2]string{
		{`fields.headers."User-Agent"`, "user-agent"},
		{"fields.uri", "uri"},
		{"fields.host", "host"},
		{`fields.headers."Referer"`, "referer"},
		{"fields.username", "username"},
		{"fields.password", "password"},
		{"fields.method", "method"},
	}

	// sshFields lists the fields aggregated for "ssh" events.
	sshFields = [][2]string{
		{"fields.client_version", "client_version"},
		{"fields.username", "username"},
		{"fields.password", "password"},
	}

	// telnetFields lists the fields aggregated for "telnet" events.
	telnetFields = [][2]string{
		{"fields.username", "username"},
		{"fields.password", "password"},
	}

	// ftpFields lists the fields aggregated for "ftp" events.
	ftpFields = [][2]string{
		{"fields.username", "username"},
		{"fields.password", "password"},
	}

	// rdpFields lists the fields aggregated for "rdp" events.
	rdpFields = [][2]string{
		{"fields.username", "username"},
		{"fields.security_layer", "security_layer"},
	}
)

// GetHoneypotStats returns aggregate statistics for a single honeypot type.
//
// The result holds the total number of events of that type, the first and
// last event timestamps (RFC 3339 with nanoseconds, left empty when the
// aggregate is NULL) and, for every tracked field, the 20 most frequent
// values. Event names, remote addresses and destination ports are tracked
// for every type; "http", "ssh", "telnet", "ftp" and "rdp" add their
// protocol-specific fields. Any other eventType only gets the three common
// lists.
//
// It returns an error when the database is not connected or a query fails.
func (db *Database) GetHoneypotStats(eventType string) (*HoneypotStats, error) {
	// Every top-N list in the result is capped at this many entries.
	const topN = 20

	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	stats := &HoneypotStats{}

	var protoFields [][2]string
	switch eventType {
	case "http":
		protoFields = httpFields
	case "ssh":
		protoFields = sshFields
	case "telnet":
		protoFields = telnetFields
	case "ftp":
		protoFields = ftpFields
	case "rdp":
		protoFields = rdpFields
	}

	// Get overall stats
	var firstSeen, lastSeen sql.NullTime
	err := db.DB.QueryRow(`
		SELECT COUNT(*), MIN(time), MAX(time) 
		FROM honeypot_events 
		WHERE type = ?`, eventType).Scan(&stats.TotalEvents, &firstSeen, &lastSeen)
	if err != nil {
		return nil, err
	}
	if firstSeen.Valid {
		stats.FirstSeen = firstSeen.Time.Format(time.RFC3339Nano)
	}
	if lastSeen.Valid {
		stats.LastSeen = lastSeen.Time.Format(time.RFC3339Nano)
	}

	// Assemble the work list in a fresh slice so the shared package-level
	// tables are never the target of an append.
	fields := make([][2]string, 0, len(protoFields)+3)
	fields = append(fields, protoFields...)
	fields = append(fields,
		[2]string{"event", "event"},
		[2]string{"remote_addr", "remote_addr"},
		[2]string{"dst_port", "dst_port"},
	)

	for _, field := range fields {
		expr, key := field[0], field[1]

		// Resolve the destination list once per field rather than per row.
		var dst *[]LabelCount
		switch key {
		case "event":
			dst = &stats.EventTypes
		case "remote_addr":
			dst = &stats.RemoteAddrs
		case "dst_port":
			dst = &stats.Ports
		case "method":
			dst = &stats.Methods
		case "user-agent":
			dst = &stats.UserAgents
		case "uri":
			dst = &stats.URIs
		case "host":
			dst = &stats.Hosts
		case "referer":
			dst = &stats.Referers
		case "client_version":
			dst = &stats.ClientVersions
		case "username":
			dst = &stats.Usernames
		case "password":
			dst = &stats.Passwords
		case "security_layer":
			dst = &stats.SecurityLayers
		}

		counts, err := db.GetTopNFields(expr, "type = ?", []any{eventType}, topN)
		if err != nil {
			return nil, err
		}
		if dst != nil {
			*dst = append(*dst, counts...)
		}
	}

	return stats, nil
}

// GetFirstLastSeenTotalForNet reports, for the events matching an IP address
// or subnet, the earliest timestamp, the latest timestamp and the event
// count, in that order. The filter is built by GetIpWhere. Timestamps are
// formatted as RFC 3339 with nanoseconds and are empty strings when the
// corresponding aggregate is NULL.
//
// It returns an error when the database is not connected, when GetIpWhere
// rejects the input, or when the query fails.
func (db *Database) GetFirstLastSeenTotalForNet(network string) (string, string, int, error) {
	if db.DB == nil {
		return "", "", 0, errors.New("database is not connected")
	}

	cond, params, err := GetIpWhere(network)
	if err != nil {
		return "", "", 0, err
	}

	stmt := fmt.Sprintf(`
		SELECT COUNT(*) as count, MIN(time) as first_seen, MAX(time) as last_seen
		FROM honeypot_events
		WHERE %s
	`, cond)

	var (
		total    int
		earliest sql.NullTime
		latest   sql.NullTime
	)
	if err = db.DB.QueryRow(stmt, params...).Scan(&total, &earliest, &latest); err != nil {
		return "", "", 0, err
	}

	// A NULL aggregate is rendered as the empty string.
	render := func(ts sql.NullTime) string {
		if !ts.Valid {
			return ""
		}
		return ts.Time.Format(time.RFC3339Nano)
	}

	return render(earliest), render(latest), total, nil
}

// GetPortStatsOverview reports, for the events whose dst_port equals port,
// the earliest timestamp, the latest timestamp and the event count, in that
// order. Timestamps are formatted as RFC 3339 with nanoseconds and are empty
// strings when the corresponding aggregate is NULL.
//
// It returns an error when the database is not connected or the query fails.
func (db *Database) GetPortStatsOverview(port int) (string, string, int, error) {
	if db.DB == nil {
		return "", "", 0, errors.New("database is not connected")
	}

	const stmt = `
		SELECT COUNT(*) as count, MIN(time) as first_seen, MAX(time) as last_seen
		FROM honeypot_events
		WHERE dst_port = ?
	`

	var (
		total    int
		earliest sql.NullTime
		latest   sql.NullTime
	)
	if err := db.DB.QueryRow(stmt, port).Scan(&total, &earliest, &latest); err != nil {
		return "", "", 0, err
	}

	var first, last string
	if earliest.Valid {
		first = earliest.Time.Format(time.RFC3339Nano)
	}
	if latest.Valid {
		last = latest.Time.Format(time.RFC3339Nano)
	}
	return first, last, total, nil
}

// GetSubnetStats returns one entry per distinct remote address matching the
// given IP address or subnet, with the number of events recorded for it. The
// filter is built by GetIpWhere. Entries are ordered by descending count,
// ties broken by ascending address.
//
// It returns an error when the database is not connected, when GetIpWhere
// rejects the input, or when the query, a row scan or the row iteration
// fails.
func (db *Database) GetSubnetStats(subnet string) ([]LabelCount, error) {
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	where, args, err := GetIpWhere(subnet)
	if err != nil {
		return nil, err
	}

	query := fmt.Sprintf(`
		SELECT remote_addr, COUNT(*) as count
		FROM honeypot_events
		WHERE %s
		GROUP BY remote_addr
		ORDER BY count DESC, remote_addr ASC
	`, where)
	rows, err := db.DB.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var stats []LabelCount
	for rows.Next() {
		var ac LabelCount
		if err := rows.Scan(&ac.Label, &ac.Count); err != nil {
			return nil, err
		}
		stats = append(stats, ac)
	}
	// Next returns false both when the rows are exhausted and when iteration
	// fails; without this check a broken stream would look like a short list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return stats, nil
}

// GetIpWhere builds a SQL WHERE fragment, together with its placeholder
// arguments, that matches events from a single IP address or a subnet.
//
//   - A value without "/" is converted with utils.IPToInt and compared for
//     equality against remote_ip_int; a conversion failure is returned.
//   - A CIDR that denotes an IPv4 network, including the IPv4-mapped IPv6
//     form such as "::ffff:10.0.0.0/104", becomes an inclusive range check
//     on remote_ip_int.
//   - Any other value containing "/" (an IPv6 network, or text that does not
//     parse as a CIDR) is passed through unchanged as an INET containment
//     test on remote_addr; validating it is left to the database.
func GetIpWhere(ip string) (string, []any, error) {
	if !strings.Contains(ip, "/") {
		ipInt, err := utils.IPToInt(ip)
		if err != nil {
			return "", nil, err
		}
		return "remote_ip_int = ?", []any{ipInt}, nil
	}

	if _, ipNet, err := net.ParseCIDR(ip); err == nil {
		if v4 := ipNet.IP.To4(); v4 != nil {
			// The mask is 4 bytes for a plain IPv4 CIDR but 16 bytes for an
			// IPv4-mapped IPv6 one. The bits covering the IPv4 address are
			// always the last four bytes, so read those in both cases.
			maskBytes := ipNet.Mask[len(ipNet.Mask)-net.IPv4len:]
			mask := binary.BigEndian.Uint32(maskBytes)
			start := binary.BigEndian.Uint32(v4)
			end := start | ^mask
			return "remote_ip_int >= ? AND remote_ip_int <= ?", []any{start, end}, nil
		}
	}

	// IPv6 network, or input ParseCIDR rejected.
	return "remote_addr::INET <<= ?::INET", []any{ip}, nil
}

// GetActivityOverTime returns the event rate per time bucket and honeypot
// type for the events selected by the query parameters in q.
//
// q is parsed by parseEventQuery. The bucket width is chosen by
// calculateBucketInterval from the query's time range, and each returned
// Count is the number of events in the bucket divided by the bucket width in
// minutes. Rows are ordered by ascending bucket, then by descending count.
//
// It returns an error when the database is not connected, when q cannot be
// parsed, or when the query, a row scan or the row iteration fails.
func (db *Database) GetActivityOverTime(q url.Values) ([]ActivityCount, error) {
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	query, err := db.parseEventQuery(q)
	if err != nil {
		return nil, err
	}

	bucketInterval, bucketMinutes := db.calculateBucketInterval(query.TimeStart, query.TimeEnd)

	// The ips table is only joined when one of these filters is present.
	hasGeo := len(query.ASNs) > 0 || len(query.Countries) > 0 || len(query.Cities) > 0 || len(query.Domains) > 0 || len(query.FQDNs) > 0

	whereClauses, args := buildWhereClauses(query, hasGeo)
	whereStr := ""
	if len(whereClauses) > 0 {
		whereStr = "WHERE " + strings.Join(whereClauses, " AND ")
	}

	joinStr := ""
	typeField := "type"
	timeField := "time"
	if hasGeo {
		// With the join in place the column names are qualified with the
		// events table name.
		joinStr = "JOIN ips ON honeypot_events.remote_addr = ips.ip"
		typeField = "honeypot_events.type"
		timeField = "honeypot_events.time"
	}

	// Named stmt rather than sql so the database/sql package is not shadowed.
	stmt := fmt.Sprintf(`
		SELECT time_bucket(INTERVAL '%s', %s) AS bucket, %s, CAST(COUNT(*) AS FLOAT) / %f as count
		FROM honeypot_events
		%s
		%s
		GROUP BY bucket, honeypot_events.type
		ORDER BY bucket ASC, count DESC
	`, bucketInterval, timeField, typeField, bucketMinutes, joinStr, whereStr)

	rows, err := db.DB.Query(stmt, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var activity []ActivityCount
	for rows.Next() {
		var ac ActivityCount
		if err := rows.Scan(&ac.Time, &ac.Type, &ac.Count); err != nil {
			return nil, err
		}
		activity = append(activity, ac)
	}
	// Next returns false both when the rows are exhausted and when iteration
	// fails; without this check a broken stream would look like a short list.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return activity, nil
}

// calculateBucketInterval picks the bucket width for a time range. It
// returns the width as interval text (for example "5 MINUTES") together with
// the same width expressed in minutes.
//
// A zero start yields one-hour buckets. A zero end is replaced by the current
// time. The order of start and end does not matter: the absolute distance
// between them selects the width.
func (db *Database) calculateBucketInterval(start, end time.Time) (string, float64) {
	const day = 24 * time.Hour

	// Ordered from the shortest span to the longest; the first tier whose
	// limit is not exceeded wins.
	tiers := []struct {
		limit    time.Duration
		interval string
		minutes  float64
	}{
		{day, "1 MINUTE", 1.0},
		{7 * day, "5 MINUTES", 5.0},
		{14 * day, "15 MINUTES", 15.0},
		{30 * day, "1 HOUR", 60.0},
		{6 * 30 * day, "6 HOURS", 360.0},
		{12 * 30 * day, "12 HOURS", 720.0},
	}

	if start.IsZero() {
		return "1 HOUR", 60.0
	}
	if end.IsZero() {
		end = time.Now()
	}

	span := end.Sub(start)
	if span < 0 {
		span = -span
	}

	for _, tier := range tiers {
		if span <= tier.limit {
			return tier.interval, tier.minutes
		}
	}
	return "1 DAY", 1440.0
}