// internal/database/ip_info.go

package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"honeypot/internal/geodb"
	"honeypot/internal/utils"
	"log"
	"log/slog"
	"strconv"
	"strings"
	"sync"
	"time"
)

// GetUniqueIPsFromEvents returns the distinct remote addresses recorded in the
// honeypot_events table, ordered by address.
//
// When includeFresh is true every address is returned and olderThan is ignored.
// When includeFresh is false, addresses that have a row in the ips table with
// last_updated later than now-olderThan are excluded, leaving only addresses
// that are missing from ips or whose metadata is older than olderThan.
//
// It returns the first error encountered while running the query, scanning a
// row, or iterating the result set; the returned slice is nil in that case.
func (db *Database) GetUniqueIPsFromEvents(includeFresh bool, olderThan time.Duration) ([]string, error) {
	query := "SELECT DISTINCT remote_addr FROM honeypot_events"
	var args []any

	if !includeFresh {
		query = `
			SELECT DISTINCT remote_addr 
			FROM honeypot_events 
			WHERE remote_addr NOT IN (
				SELECT ip FROM ips WHERE last_updated > ?
			)
		`
		args = append(args, time.Now().Add(-olderThan))
	}

	query += " ORDER BY remote_addr"

	rows, err := db.DB.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var ips []string
	for rows.Next() {
		var ip string
		if err := rows.Scan(&ip); err != nil {
			return nil, err
		}
		ips = append(ips, ip)
	}
	// rows.Next also returns false when iteration fails part-way through;
	// without this check a truncated list would be returned as if complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return ips, nil
}

// IPMetadata represents the metadata stored for one IP address. Its fields map
// one-to-one onto the columns of the ips table (ip, ip_int, country, asn,
// asn_org, city, latitude, longitude, fqdn, domain, last_updated) as written by
// UpsertIPMetadata and read by GetIPMetadata, and it serializes to JSON with
// the same snake_case names.
//
// Field notes:
//   - IPInt is the numeric form of IP as produced by utils.IPToInt;
//     UpdateIPInfo leaves it at 0 when that conversion fails.
//   - Country is filled from the lookup's CountryCode by UpdateIPInfo.
//   - LastUpdated is populated when a row is read. UpsertIPMetadata ignores
//     the value supplied by the caller and stores the current time instead.
type IPMetadata struct {
	IP          string    `json:"ip"`
	IPInt       uint32    `json:"ip_int"`
	Country     string    `json:"country"`
	ASN         int       `json:"asn"`
	ASNOrg      string    `json:"asn_org"`
	City        string    `json:"city"`
	Latitude    float64   `json:"latitude"`
	Longitude   float64   `json:"longitude"`
	FQDN        string    `json:"fqdn"`
	Domain      string    `json:"domain"`
	LastUpdated time.Time `json:"last_updated"`
}

// upsertMu serializes UpsertIPMetadata calls. It is package-level state, so it
// is shared by every Database value in the process.
var upsertMu sync.Mutex

// UpsertIPMetadata inserts m into the ips table, or, when a row with the same
// ip already exists, overwrites every other column of that row with the values
// from m.
//
// last_updated is always set to the current time; m.LastUpdated is ignored.
// Calls are serialized through upsertMu, so at most one upsert statement issued
// by this function runs at a time.
//
// It returns an error if m is nil or if the statement fails.
func (db *Database) UpsertIPMetadata(m *IPMetadata) error {
	// Reject nil up front: dereferencing it below would panic, which in
	// UpdateIPInfo would take down a worker goroutine and the whole process.
	if m == nil {
		return errors.New("ip metadata is nil")
	}

	upsertMu.Lock()
	defer upsertMu.Unlock()

	now := time.Now()
	_, err := db.DB.Exec(`
		INSERT INTO ips (
			ip, ip_int, country, asn, asn_org, city, 
			latitude, longitude, fqdn, domain, last_updated
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT (ip) DO UPDATE SET
			ip_int = excluded.ip_int,
			country = excluded.country,
			asn = excluded.asn,
			asn_org = excluded.asn_org,
			city = excluded.city,
			latitude = excluded.latitude,
			longitude = excluded.longitude,
			fqdn = excluded.fqdn,
			domain = excluded.domain,
			last_updated = excluded.last_updated
	`, m.IP, m.IPInt, m.Country, m.ASN, m.ASNOrg, m.City,
		m.Latitude, m.Longitude, m.FQDN, m.Domain, now)
	return err
}

// GetIPMetadata loads the ips-table rows for the given addresses and returns
// them keyed by IP.
//
// Addresses without a row in the ips table are simply absent from the result.
// An empty or nil ips yields an empty, non-nil map without touching the
// database. On any query, scan, or iteration error the map is nil and the
// error is returned.
func (db *Database) GetIPMetadata(ips []string) (map[string]*IPMetadata, error) {
	if len(ips) == 0 {
		return make(map[string]*IPMetadata), nil
	}

	// One "?" per address; the addresses themselves are always passed as bound
	// parameters and never interpolated into the SQL text.
	placeholders := strings.TrimRight(strings.Repeat("?, ", len(ips)), ", ")
	query := fmt.Sprintf(`
		SELECT ip, ip_int, country, asn, asn_org, city, 
		       latitude, longitude, fqdn, domain, last_updated
		FROM ips WHERE ip IN (%s)
	`, placeholders)

	args := make([]any, len(ips))
	for i, ip := range ips {
		args[i] = ip
	}

	rows, err := db.DB.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// At most one row per requested address can come back.
	results := make(map[string]*IPMetadata, len(ips))
	for rows.Next() {
		m := new(IPMetadata)
		err := rows.Scan(
			&m.IP, &m.IPInt, &m.Country, &m.ASN, &m.ASNOrg, &m.City,
			&m.Latitude, &m.Longitude, &m.FQDN, &m.Domain, &m.LastUpdated,
		)
		if err != nil {
			return nil, err
		}
		results[m.IP] = m
	}
	// rows.Next also returns false when iteration fails part-way through;
	// surface that instead of returning a partial map as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}

// isUpdatingIPInfo reports whether an UpdateIPInfo run is currently in
// progress; updateMu guards it. Both are package-level, so the "one run at a
// time" guarantee spans every Database value in the process.
var (
	isUpdatingIPInfo bool
	updateMu         sync.Mutex
)

// UpdateIPInfo refreshes the ips table for every remote address seen in
// honeypot_events whose metadata is missing or older than three days.
//
// Each address is looked up through geoDb and the result is upserted via
// UpsertIPMetadata, with up to 50 lookups in flight at once. Failures are not
// returned: upsert errors are logged on l at error level, lookup errors at
// debug level, and the affected address is skipped.
//
// Only one run executes at a time. A call made while another run is in
// progress returns immediately without doing any work. Cancelling ctx stops
// new lookups from being started; the function still waits for lookups already
// in flight before returning.
//
// db.IPInfoLastRun is set to the current time when the run ends, including
// when it was cut short by cancellation. It is not set when the initial
// address query fails or when the call is skipped because another run is
// active.
func (db *Database) UpdateIPInfo(ctx context.Context, geoDb *geodb.GeoDB, l *slog.Logger) {
	const (
		staleAfter           = 3 * 24 * time.Hour
		maxConcurrentLookups = 50
	)

	updateMu.Lock()
	if isUpdatingIPInfo {
		updateMu.Unlock()
		return
	}
	isUpdatingIPInfo = true
	updateMu.Unlock()

	defer func() {
		updateMu.Lock()
		isUpdatingIPInfo = false
		updateMu.Unlock()
	}()

	ips, err := db.GetUniqueIPsFromEvents(false, staleAfter)
	if err != nil {
		l.Error("failed to get unique IPs from events", "error", err)
		return
	}

	var wg sync.WaitGroup
	// sem bounds the number of concurrent lookups: a slot is taken before a
	// goroutine starts and released when it finishes.
	sem := make(chan struct{}, maxConcurrentLookups)

ProcessingLoop:
	for _, ip := range ips {
		// select chooses randomly among ready cases, so a free slot could win
		// over an already-cancelled context; check cancellation explicitly.
		if ctx.Err() != nil {
			break
		}

		select {
		case <-ctx.Done():
			break ProcessingLoop
		case sem <- struct{}{}:
			wg.Add(1)
			go func(ip string) {
				defer wg.Done()
				defer func() { <-sem }()

				meta, err := geoDb.LookupMetadata(ctx, ip)
				if err != nil {
					// Failures caused by cancellation are expected noise;
					// anything else is recorded so it can be diagnosed.
					if ctx.Err() == nil {
						l.Debug("failed to look up IP metadata", "ip", ip, "error", err)
					}
					return
				}

				dbMeta := &IPMetadata{
					IP:        meta.IP,
					Country:   meta.CountryCode,
					ASN:       meta.ASN,
					ASNOrg:    meta.ASNOrg,
					City:      meta.City,
					Latitude:  meta.Latitude,
					Longitude: meta.Longitude,
					FQDN:      meta.FQDN,
					Domain:    meta.Domain,
				}

				// IPInt stays at its zero value when the address cannot be
				// converted to an integer.
				if ipInt, err := utils.IPToInt(ip); err == nil {
					dbMeta.IPInt = ipInt
				}

				if err := db.UpsertIPMetadata(dbMeta); err != nil {
					l.Error("failed to upsert IP metadata", "ip", ip, "error", err)
				}
			}(ip)
		}
	}
	wg.Wait()
	db.IPInfoLastRun = time.Now()
}

// GetGeoStats aggregates the honeypot events whose source IP matches a geo
// attribute in the ips table.
//
// geoType selects the ips column to match ("asn", "country", "city", "domain"
// or "fqdn"; anything else is rejected by GetGeoWhere). value is the value to
// match; surrounding whitespace is trimmed before it is used anywhere.
//
// The result contains:
//   - a title and best-effort metadata (ASN organization and/or country),
//   - the total number of matching events with first/last seen timestamps
//     formatted as RFC 3339 with nanoseconds,
//   - the top 20 event names, remote addresses, destination ports and FQDNs
//     by event count.
//
// An error is returned only when the database is not connected or geoType is
// invalid. Failures of the individual statistic queries are logged and leave
// the corresponding part of the result empty.
func (db *Database) GetGeoStats(geoType string, value string) (*HoneypotStats, error) {
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	// Trim before building the filter so the aggregate queries and the
	// title/metadata lookups below all operate on the same value.
	value = strings.TrimSpace(value)

	where, args, err := GetGeoWhere(geoType, value)
	if err != nil {
		return nil, err
	}

	stats := &HoneypotStats{}
	stats.Metadata = make(map[string]any)

	// Title and metadata. The lookups here are best-effort: a miss (typically
	// sql.ErrNoRows) just leaves the optional metadata unset.
	switch geoType {
	case "asn":
		var org, country string
		// Try to find any record for this ASN that has an organization name.
		query := "SELECT asn_org, country FROM ips WHERE asn = ? AND asn_org IS NOT NULL AND asn_org != '' LIMIT 1"
		if err := db.DB.QueryRow(query, value).Scan(&org, &country); err != nil {
			// Retry with an integer binding in case the string comparison
			// failed, but only for numeric input: retrying with the zero
			// value would attach ASN 0's details to an unrelated title.
			if asnInt, convErr := strconv.Atoi(value); convErr == nil {
				_ = db.DB.QueryRow(query, asnInt).Scan(&org, &country)
			}
		}

		stats.Title = "ASN " + value
		if org != "" {
			stats.Metadata["asn_org"] = org
		}
		if country != "" {
			stats.Metadata["country"] = country
		}
	case "country":
		var country string
		_ = db.DB.QueryRow("SELECT DISTINCT country FROM ips WHERE country = ? AND country IS NOT NULL LIMIT 1", value).Scan(&country)
		// Prefer the value as stored in the database, falling back to the
		// caller's input when no row matches.
		stats.Title = value
		if country != "" {
			stats.Title = country
		}
		stats.Metadata["country"] = value
	case "city":
		var country string
		_ = db.DB.QueryRow("SELECT DISTINCT country FROM ips WHERE city = ? AND country IS NOT NULL LIMIT 1", value).Scan(&country)
		stats.Title = value
		if country != "" {
			stats.Metadata["country"] = country
		}
	case "domain", "fqdn":
		stats.Title = value
	}

	// Total events, first seen, last seen.
	var firstSeen, lastSeen sql.NullTime
	err = db.DB.QueryRow(fmt.Sprintf(`
		SELECT COUNT(*), MIN(time), MAX(time)
		FROM honeypot_events
		JOIN ips ON honeypot_events.remote_addr = ips.ip
		WHERE %s
	`, where), args...).Scan(&stats.TotalEvents, &firstSeen, &lastSeen)
	if err != nil {
		log.Printf("failed to get geo event overview: %v", err)
	}
	if firstSeen.Valid {
		stats.FirstSeen = firstSeen.Time.Format(time.RFC3339Nano)
	}
	if lastSeen.Valid {
		stats.LastSeen = lastSeen.Time.Format(time.RFC3339Nano)
	}

	// Each entry pairs the SQL expression to group by with the name used to
	// route the rows into the result. A former "type" entry was dropped: its
	// rows were never stored anywhere, so the query was pure overhead.
	fields := [][2]string{
		{"honeypot_events.event", "event"},
		{"honeypot_events.remote_addr", "remote_addr"},
		{"honeypot_events.dst_port", "dst_port"},
		{"ips.fqdn", "fqdn"},
	}

	// The filter arguments followed by the LIMIT value, built once into a
	// fresh slice so the slice returned by GetGeoWhere is never appended to.
	const topN = 20
	queryArgs := make([]any, 0, len(args)+1)
	queryArgs = append(queryArgs, args...)
	queryArgs = append(queryArgs, topN)

	for _, field := range fields {
		selectField, name := field[0], field[1]

		// Run each aggregate in its own function so rows.Close fires at the
		// end of the iteration rather than when GetGeoStats returns.
		func() {
			rows, err := db.DB.Query(fmt.Sprintf(`
				SELECT %s as field, COUNT(*) as count 
				FROM honeypot_events 
				JOIN ips ON honeypot_events.remote_addr = ips.ip
				WHERE %s AND field IS NOT NULL
				GROUP BY field 
				ORDER BY count DESC
				LIMIT ?
				`, selectField, where), queryArgs...)
			if err != nil {
				log.Printf("failed to query geo stats for field %s: %v", name, err)
				return
			}
			defer rows.Close()

			for rows.Next() {
				var fc LabelCount
				if err := rows.Scan(&fc.Label, &fc.Count); err != nil {
					log.Printf("failed to scan geo stats result for field %s: %v", name, err)
					return
				}
				switch name {
				case "event":
					stats.EventTypes = append(stats.EventTypes, fc)
				case "remote_addr":
					stats.RemoteAddrs = append(stats.RemoteAddrs, fc)
				case "dst_port":
					stats.Ports = append(stats.Ports, fc)
				case "fqdn":
					stats.FQDNs = append(stats.FQDNs, fc)
				}
			}
			// rows.Next also stops on an iteration error; log it so a
			// truncated top list does not go unnoticed.
			if err := rows.Err(); err != nil {
				log.Printf("failed to iterate geo stats result for field %s: %v", name, err)
			}
		}()
	}

	return stats, nil
}

// GetGeoWhere builds the SQL filter used to select rows of the ips table by a
// geo attribute.
//
// geoType must be one of "asn", "country", "city", "domain" or "fqdn". On
// success it returns a condition of the form "ips.<geoType> = ?" together with
// a one-element argument list holding value, ready to be bound to the
// placeholder. For any other geoType it returns an empty condition, nil
// arguments and an error naming the rejected type.
func GetGeoWhere(geoType string, value string) (string, []any, error) {
	switch geoType {
	case "asn", "country", "city", "domain", "fqdn":
		// Every supported geo type is also the name of its column in ips.
		// geoType is only spliced into the SQL after matching this fixed
		// list; the user-supplied value always travels as a bound argument.
		return "ips." + geoType + " = ?", []any{value}, nil
	}
	return "", nil, fmt.Errorf("invalid geo type: %s", geoType)
}