internal/database/database.go

package database

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"honeypot/internal/types"
	"honeypot/internal/utils"
	"log"
	"net/url"
	"os"
	"slices"
	"strconv"
	"strings"
	"time"

	"github.com/duckdb/duckdb-go/v2"
)

// Database wraps a DuckDB database that is opened through a duckdb.Connector
// and exposed as a database/sql handle.
type Database struct {
	// DB is the database/sql handle; NewDatabase opens it on Connector.
	DB            *sql.DB
	// Connector is the DuckDB connector DB was opened with. It is kept so
	// DuckDB-specific features can be used (NewDatabase mentions appender support).
	Connector     *duckdb.Connector
	// Path is the database file. GetDatabaseStats also stats Path+".wal".
	Path          string
	// IPInfoLastRun is not read or written in this file.
	// NOTE(review): presumably the time of the last IP-info enrichment run — confirm.
	IPInfoLastRun time.Time
}

// NewDatabase opens (creating it if necessary) the DuckDB database stored in
// databaseFile and ensures the schema from CreateTables exists.
//
// It returns nil if databaseFile is empty. Failure to create the DuckDB
// connector terminates the process via log.Fatalf. Failure to create the
// schema is only logged; the returned *Database is still usable for whatever
// tables already exist.
func NewDatabase(databaseFile string) *Database {
	if databaseFile == "" {
		return nil
	}

	switch _, err := os.Stat(databaseFile); {
	case err == nil:
		fmt.Println("Database file exists, using existing database")
	case errors.Is(err, os.ErrNotExist):
		fmt.Println("Database file does not exist, creating new database")
	default:
		// Not "missing" but still unreadable (e.g. permissions). DuckDB will
		// report the precise problem when the connector opens the file.
		log.Printf("could not stat database file %q: %v", databaseFile, err)
	}

	// Create a connector for appender support
	connector, err := duckdb.NewConnector(databaseFile, nil)
	if err != nil {
		log.Fatalf("failed to create duckdb connector: %v", err)
	}

	database := &Database{DB: sql.OpenDB(connector), Connector: connector, Path: databaseFile}

	// Every statement in CreateTables is IF NOT EXISTS, so running it on an
	// existing file is harmless. Running it unconditionally repairs a file
	// whose first schema creation failed, and adds tables introduced after the
	// file was created.
	fmt.Println("Ensuring database tables exist")
	if err := database.CreateTables(); err != nil {
		log.Printf("failed to create database tables: %v", err)
	}

	return database
}

// CreateTables creates any missing schema objects:
//   - honeypot_events, with id_sequence and indexes on type, event,
//     remote_ip_int, dst_port and time;
//   - blocklist, with blocklist_id_sequence; expires defaults to three hours
//     after insertion;
//   - ips, keyed by the textual address, with indexes on ip_int, country and asn.
//
// Every statement uses IF NOT EXISTS, so calling it on an initialised
// database leaves existing objects untouched.
//
// NOTE(review): the id columns are 32-bit INTEGER. Consider BIGINT for
// long-running instances. Existing tables are not altered by IF NOT EXISTS.
func (db *Database) CreateTables() error {
	if db == nil || db.DB == nil {
		return errors.New("database is not connected")
	}

	_, err := db.DB.Exec(`
		CREATE SEQUENCE IF NOT EXISTS id_sequence START 1;

		CREATE TABLE IF NOT EXISTS honeypot_events (
			id INTEGER PRIMARY KEY DEFAULT nextval('id_sequence'),
			time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
			type TEXT NOT NULL,
			event TEXT NOT NULL,
			remote_addr TEXT NOT NULL,
			remote_ip_int UBIGINT,
			remote_port USMALLINT NOT NULL,
			dst_port USMALLINT NOT NULL,
			fields JSON
		);

		CREATE INDEX IF NOT EXISTS idx_events_type ON honeypot_events (type);
		CREATE INDEX IF NOT EXISTS idx_events_event ON honeypot_events (event);
		CREATE INDEX IF NOT EXISTS idx_events_remote_ip_int ON honeypot_events (remote_ip_int);
		CREATE INDEX IF NOT EXISTS idx_events_dst_port ON honeypot_events (dst_port);
		CREATE INDEX IF NOT EXISTS idx_events_time ON honeypot_events (time);

		CREATE SEQUENCE IF NOT EXISTS blocklist_id_sequence START 1;

		CREATE TABLE IF NOT EXISTS blocklist (
			id INTEGER PRIMARY KEY DEFAULT nextval('blocklist_id_sequence'),
			address TEXT NOT NULL,
			timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
			expires TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '3 HOURS',
			reason TEXT NOT NULL
		);

		CREATE TABLE IF NOT EXISTS ips (
			ip TEXT PRIMARY KEY,
			ip_int UBIGINT,
			country TEXT,
			asn INTEGER,
			asn_org TEXT,
			city TEXT,
			latitude DOUBLE,
			longitude DOUBLE,
			fqdn TEXT,
			domain TEXT,
			last_updated TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
		);

		CREATE INDEX IF NOT EXISTS idx_ips_ip_int ON ips (ip_int);
		CREATE INDEX IF NOT EXISTS idx_ips_country ON ips (country);
		CREATE INDEX IF NOT EXISTS idx_ips_asn ON ips (asn);
	`)
	return err
}

// Close releases the database handle and its DuckDB connector.
//
// It is a no-op on a nil *Database, which is what NewDatabase returns when
// no database file is configured.
func (db *Database) Close() error {
	if db == nil {
		return nil
	}
	if db.DB != nil {
		// (*sql.DB).Close also closes the connector it was opened with when
		// that connector implements io.Closer. *duckdb.Connector does: it has
		// Close() error. NewDatabase opens DB on Connector, so closing
		// Connector again here would close it twice.
		// NOTE(review): assumes DB was opened on Connector, as NewDatabase does.
		return db.DB.Close()
	}
	if db.Connector != nil {
		// No *sql.DB owns the connector, so close it directly.
		return db.Connector.Close()
	}
	return nil
}

// Checkpoint runs DuckDB's FORCE CHECKPOINT, which writes the write-ahead log
// into the main database file. It returns an error when no *sql.DB is set.
//
// NOTE(review): per DuckDB docs the FORCE variant does not wait for running
// transactions (it aborts them). Confirm this is acceptable for callers.
func (db *Database) Checkpoint() error {
	if conn := db.DB; conn != nil {
		_, execErr := conn.Exec("FORCE CHECKPOINT")
		return execErr
	}
	return errors.New("database is not connected")
}

// InsertEvent appends one event to the honeypot_events table.
//
// The id is taken from id_sequence. event.Fields, when non-nil, is
// JSON-encoded into the fields column; a nil Fields value is stored as SQL
// NULL. The remote address is also stored in numeric form (utils.IPToInt) in
// the indexed remote_ip_int column. Errors from encoding, address conversion
// and the INSERT itself are returned with context.
func (db *Database) InsertEvent(event *types.LogEvent) error {
	if db == nil || db.DB == nil {
		return errors.New("database is not connected")
	}
	if event == nil {
		return errors.New("cannot insert nil event")
	}

	var fields any // stays nil, and is stored as NULL, when there are no fields
	if event.Fields != nil {
		encoded, err := json.Marshal(event.Fields)
		if err != nil {
			return fmt.Errorf("encoding event fields: %w", err)
		}
		// Bind the document as text. JSON is a text type, while a []byte
		// argument is a binary value to database/sql drivers.
		// NOTE(review): duckdb-go presumably binds []byte as BLOB, so a string
		// avoids depending on a BLOB-to-JSON conversion for non-ASCII content.
		fields = string(encoded)
	}

	remoteIPInt, err := utils.IPToInt(event.RemoteAddr)
	if err != nil {
		return fmt.Errorf("converting remote address %v to integer: %w", event.RemoteAddr, err)
	}

	_, err = db.DB.Exec(`
		INSERT INTO honeypot_events (id, time, type, event, remote_addr, remote_ip_int, remote_port, dst_port, fields)
		VALUES (NEXTVAL('id_sequence'), ?, ?, ?, ?, ?, ?, ?, ?)
	`, event.Time, event.Type, event.Event, event.RemoteAddr, remoteIPInt, event.RemotePort, event.DstPort, fields)
	if err != nil {
		return fmt.Errorf("inserting event: %w", err)
	}
	return nil
}

// DatabaseStats summarises table sizes and on-disk usage. It is filled in by
// GetDatabaseStats.
type DatabaseStats struct {
	// RowsEvents is the number of rows in honeypot_events.
	RowsEvents        int   `json:"rows_events"`
	// RowsBlocklist is the number of rows in blocklist. Expired entries are
	// not filtered out.
	RowsBlocklist     int   `json:"rows_blocklist"`
	// RowsIps is the number of rows in ips.
	RowsIps           int   `json:"rows_ips"`
	// RowsUnresolvedIPs counts distinct honeypot_events.remote_addr values
	// with no ips row updated in the last 72 hours.
	RowsUnresolvedIPs int   `json:"rows_unresolved_ips"`
	// DatabaseSize is the size in bytes of the database file plus its WAL file.
	DatabaseSize      int64 `json:"database_size"`
	// WalSize is the size in bytes of the write-ahead log (Path + ".wal").
	WalSize           int64 `json:"wal_size"`
}

// GetDatabaseStats returns row counts for the package's tables and the
// on-disk size of the database.
//
// RowsUnresolvedIPs counts distinct event source addresses that have no ips
// row refreshed within the last 72 hours. DatabaseSize is the size of the
// database file plus its write-ahead log; WalSize is also reported on its
// own. Size fields stay zero when the corresponding file cannot be stat'ed.
// On a query error the zero DatabaseStats is returned.
func (db *Database) GetDatabaseStats() (DatabaseStats, error) {
	if db == nil || db.DB == nil {
		return DatabaseStats{}, errors.New("database is not connected")
	}

	var stats DatabaseStats

	// IP info older than this is considered stale, i.e. still unresolved.
	resolvedSince := time.Now().Add(-72 * time.Hour)
	err := db.DB.QueryRow(`
		SELECT
			(SELECT COUNT(*) FROM honeypot_events),
			(SELECT COUNT(*) FROM blocklist),
			(SELECT COUNT(*) FROM ips),
			(SELECT COUNT(DISTINCT remote_addr) FROM honeypot_events WHERE remote_addr NOT IN (SELECT ip FROM ips WHERE last_updated > ?))
	`, resolvedSince).Scan(
		&stats.RowsEvents, &stats.RowsBlocklist, &stats.RowsIps, &stats.RowsUnresolvedIPs,
	)
	if err != nil {
		// Do not hand back partially scanned counts.
		return DatabaseStats{}, fmt.Errorf("querying row counts: %w", err)
	}

	if db.Path != "" {
		if info, err := os.Stat(db.Path); err == nil {
			stats.DatabaseSize = info.Size()
		}
		// The WAL holds data not yet checkpointed into the main file, so it is
		// counted towards the total size as well.
		if info, err := os.Stat(db.Path + ".wal"); err == nil {
			stats.WalSize = info.Size()
			stats.DatabaseSize += stats.WalSize
		}
	}

	return stats, nil
}

// InsertBlocklist adds address to the blocklist table with the given reason.
// The entry expires duration from now.
//
// The row's timestamp column is filled by its table default
// (CURRENT_TIMESTAMP), while expires is computed from the Go process clock.
func (db *Database) InsertBlocklist(address, reason string, duration time.Duration) error {
	if db == nil || db.DB == nil {
		return errors.New("database is not connected")
	}

	expires := time.Now().Add(duration)
	if _, err := db.DB.Exec("INSERT INTO blocklist (address, expires, reason) VALUES (?, ?, ?)", address, expires, reason); err != nil {
		return fmt.Errorf("inserting blocklist entry for %q: %w", address, err)
	}
	return nil
}

// PortCount pairs a port number with an occurrence count.
// NOTE(review): not used in this file; presumably filled by an aggregation
// query elsewhere in the package — confirm.
type PortCount struct {
	Port  int `json:"port"`
	Count int `json:"count"`
}

// JSONMap holds a decoded JSON object, such as the fields column of
// honeypot_events. Its Scan method lets it serve directly as a database/sql
// scan destination.
type JSONMap map[string]any

// Scan implements sql.Scanner.
//
// It accepts:
//   - NULL, which yields a nil map;
//   - JSON text as []byte or string;
//   - an already decoded map[string]any, which is used as-is without copying.
//
// Any other source type is an error.
//
// The result always replaces the previous contents of m. json.Unmarshal keeps
// existing entries when decoding into a non-nil map, so decoding straight into
// *m would leak keys from an earlier row into this one. On a decode error, m
// is left unchanged.
func (m *JSONMap) Scan(value any) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		*m = nil
		return nil
	case map[string]any:
		*m = JSONMap(v)
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("unexpected type for JSONMap: %T", value)
	}

	var decoded JSONMap
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return err
	}
	*m = decoded
	return nil
}

// Event is one row produced by the event queries (StreamEvents,
// ExportEvents). Only the columns in the plan from buildFieldPlan are scanned.
// Other fields keep their zero value and are dropped from JSON output by the
// omitempty/omitzero tags.
type Event struct {
	ID         int       `json:"id,omitempty"`
	Time       time.Time `json:"time,omitzero"`
	Type       string    `json:"type,omitempty"`
	Event      string    `json:"event,omitempty"`
	RemoteAddr string    `json:"remote_addr,omitempty"`
	RemotePort uint16    `json:"remote_port,omitempty"`
	DstPort    uint16    `json:"dst_port,omitempty"`
	// Fields is the decoded JSON object from the fields column.
	Fields     JSONMap   `json:"fields,omitempty"`
	// Country, City, Latitude and Longitude are the geo columns. They are
	// scanned only when part of the plan: either requested explicitly, or
	// included by default when HasGeoFields reports true.
	Country    string    `json:"country,omitempty"`
	City       string    `json:"city,omitempty"`
	Latitude   float64   `json:"latitude,omitempty"`
	Longitude  float64   `json:"longitude,omitempty"`
}

// QueryResponse is the JSON envelope for an event query result.
// NOTE(review): assembled outside this file, presumably from QueryMeta plus
// the streamed events — confirm.
type QueryResponse struct {
	// Query is the SQL text that was run for the events page.
	Query     string  `json:"query"`
	// WhereArgs are the positional arguments bound to Query.
	WhereArgs []any   `json:"where_args"`
	Events    []Event `json:"events"`
	Total     int     `json:"total"` // total number of events matching the query without pagination
	// QueryTime is presumably a formatted duration (see QueryMeta.QueryTime).
	QueryTime string  `json:"query_time"`
}

// EventQuery is the parsed form of the URL parameters accepted by the event
// queries; parseEventQuery builds it from url.Values. List-valued fields
// collect both repeated parameters and comma-separated values. The filters
// are turned into SQL by buildWhereClauses.
type EventQuery struct {
	// Limit is the page size ("limit", default 100).
	Limit          int
	// Offset is the number of matching rows to skip ("offset", default 0).
	Offset         int
	// OrderDirection is "asc" or "desc" ("order_direction", default "desc").
	// It sorts by event time.
	OrderDirection string
	// RemoteAddrs filters by source address ("remote_addr").
	RemoteAddrs    []string
	// RemotePorts filters by source port ("remote_port"). Values are kept as
	// strings; they are not validated as numbers here.
	RemotePorts    []string
	// DstPorts filters by destination port ("dst_port"), kept as strings.
	DstPorts       []string
	// Event filters by the event column ("event").
	Event          []string
	// IDs filters by event id ("id"), kept as strings.
	IDs            []string
	// TimeStart and TimeEnd bound the event time ("time_start", "time_end").
	// The zero value means no bound.
	TimeStart      time.Time
	TimeEnd        time.Time
	// Type filters by the type column ("type").
	Type           []string
	// FieldFilters maps a JSON field name to accepted values. Each
	// "f:<name>" parameter contributes its values verbatim, commas included.
	FieldFilters   map[string][]string
	// FieldExists lists JSON field names taken from "fe:<name>" parameters.
	// NOTE(review): presumably requires the field to be present — confirm in buildWhereClauses.
	FieldExists    []string
	// Columns selects which Event fields are returned ("columns"); see buildFieldPlan.
	Columns        []string
	// ASNs, Countries, Cities, Domains and FQDNs are IP-metadata filters
	// ("asn", "country", "city", "domain", "fqdn").
	// NOTE(review): presumably matched against the ips table — confirm.
	ASNs           []string
	Countries      []string
	Cities         []string
	Domains        []string
	FQDNs          []string
}

// eventField identifies one selectable column of an Event row.
// buildFieldPlan turns requested column names into eventFields, and
// scanWithPlan maps each eventField to the Event struct field it is scanned
// into. In this file the values are only compared with each other, so they
// form a plain consecutive sequence starting at 0.
type eventField int

const (
	fieldID eventField = iota
	fieldTime
	fieldType
	fieldEvent
	fieldRemoteAddr
	fieldRemotePort
	fieldDstPort
	fieldFields
	fieldCountry
	fieldCity
	fieldLatitude
	fieldLongitude
)

// buildFieldPlan decides which Event fields each result row is scanned into,
// and in which order.
//
// Requested column names (q.Columns) are kept in the order given, duplicates
// included; unknown names are ignored. If no requested name is recognised,
// all eight base columns are used. The four geo columns are appended when
// HasGeoFields reports the query needs them. The error result is currently
// always nil.
func buildFieldPlan(q EventQuery) ([]eventField, error) {
	withGeo := HasGeoFields(q)

	// names[i] is the column name of fields[i]. The first baseCount entries are
	// the plain event columns; the rest are the geo columns.
	names := []string{"id", "time", "type", "event", "remote_addr", "remote_port", "dst_port", "fields", "country", "city", "latitude", "longitude"}
	fields := []eventField{fieldID, fieldTime, fieldType, fieldEvent, fieldRemoteAddr, fieldRemotePort, fieldDstPort, fieldFields, fieldCountry, fieldCity, fieldLatitude, fieldLongitude}
	const baseCount = 8

	var plan []eventField
	for _, name := range q.Columns {
		if idx := slices.Index(names, name); idx >= 0 {
			plan = append(plan, fields[idx])
		}
	}
	if len(plan) != 0 {
		return plan, nil
	}

	// Nothing usable was requested: fall back to the default column set.
	if withGeo {
		return fields, nil
	}
	return fields[:baseCount], nil
}

// scanWithPlan scans the current row of rows into event. Position i of the
// row goes into the Event field named by plan[i]. dests is caller-provided
// scratch space of at least len(plan) entries, reused across rows. Slots for
// unrecognised plan entries are left as they were.
func scanWithPlan(rows *sql.Rows, plan []eventField, event *Event, dests []any) error {
	for idx, field := range plan {
		var target any
		switch field {
		case fieldID:
			target = &event.ID
		case fieldTime:
			target = &event.Time
		case fieldType:
			target = &event.Type
		case fieldEvent:
			target = &event.Event
		case fieldRemoteAddr:
			target = &event.RemoteAddr
		case fieldRemotePort:
			target = &event.RemotePort
		case fieldDstPort:
			target = &event.DstPort
		case fieldFields:
			target = &event.Fields
		case fieldCountry:
			target = &event.Country
		case fieldCity:
			target = &event.City
		case fieldLatitude:
			target = &event.Latitude
		case fieldLongitude:
			target = &event.Longitude
		default:
			continue
		}
		dests[idx] = target
	}

	return rows.Scan(dests...)
}

// QueryMeta describes an event query without its rows. QueryEventsMeta
// returns it.
type QueryMeta struct {
	// Query is the SQL text of the data query, i.e. buildQueryString(query, false).
	Query     string
	// WhereArgs are the positional arguments bound to Query.
	WhereArgs []any
	// Total is the result of the count form of the query,
	// buildQueryString(query, true).
	Total     int
	// QueryTime is the wall time QueryEventsMeta spent, including parsing and
	// the count query.
	QueryTime time.Duration
}

// QueryEventsMeta parses the event query in q and returns metadata about it:
//   - the SQL text and arguments that StreamEvents runs for the same parameters;
//   - the number of matching events, from the count form of the query built
//     by buildQueryString;
//   - the time this took.
func (db *Database) QueryEventsMeta(q url.Values) (QueryMeta, error) {
	if db == nil || db.DB == nil {
		return QueryMeta{}, errors.New("database is not connected")
	}

	start := time.Now()

	query, err := db.parseEventQuery(q)
	if err != nil {
		return QueryMeta{}, err
	}

	// Row.Scan reports errors hit while executing or reading the row, which a
	// bare rows.Next() loop would hide. A query yielding no row counts as zero.
	totalQuery, totalWhereArgs := buildQueryString(query, true)
	var total int
	if err := db.DB.QueryRow(totalQuery, totalWhereArgs...).Scan(&total); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return QueryMeta{}, fmt.Errorf("counting events: %w", err)
	}

	queryString, whereArgs := buildQueryString(query, false)

	return QueryMeta{
		Query:     queryString,
		WhereArgs: whereArgs,
		Total:     total,
		QueryTime: time.Since(start),
	}, nil
}

// StreamEvents runs the event query that buildQueryString builds for q and
// calls handle once per result row, in result order.
//
// It returns the first error from parsing, querying, building the field
// plan, scanning or handle. Otherwise it returns the row iterator's error.
func (db *Database) StreamEvents(
	q url.Values,
	handle func(Event) error,
) error {
	parsed, err := db.parseEventQuery(q)
	if err != nil {
		return err
	}

	sqlText, args := buildQueryString(parsed, false)
	rows, err := db.DB.Query(sqlText, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	plan, err := buildFieldPlan(parsed)
	if err != nil {
		return err
	}

	// A single destination slice serves every row; scanWithPlan re-points its
	// entries at the current row's Event.
	scratch := make([]any, len(plan))
	for rows.Next() {
		var ev Event
		if err = scanWithPlan(rows, plan, &ev, scratch); err != nil {
			return err
		}
		if err = handle(ev); err != nil {
			return err
		}
	}
	return rows.Err()
}

// ExportEvents runs the event query described by q and calls handle for every
// matching row. It builds the SQL from buildSelectClause, buildWhereClauses
// (joined with AND) and an ORDER BY on time, and appends no LIMIT or OFFSET.
// It stops at the first error from parsing, querying, scanning or handle.
func (db *Database) ExportEvents(q url.Values, handle func(Event) error) error {
	parsed, err := db.parseEventQuery(q)
	if err != nil {
		return err
	}

	withGeo := HasGeoFields(parsed)

	var sqlText strings.Builder
	sqlText.WriteString(buildSelectClause(parsed, false, withGeo))

	conditions, args := buildWhereClauses(parsed, withGeo)
	if len(conditions) != 0 {
		sqlText.WriteString(" WHERE ")
		sqlText.WriteString(strings.Join(conditions, " AND "))
	}

	// parseEventQuery only lets "asc" or "desc" through, so concatenating the
	// direction into the SQL is safe.
	if dir := parsed.OrderDirection; dir != "" {
		sqlText.WriteString(" ORDER BY time ")
		sqlText.WriteString(dir)
	}

	rows, err := db.DB.Query(sqlText.String(), args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	plan, err := buildFieldPlan(parsed)
	if err != nil {
		return err
	}

	scratch := make([]any, len(plan))
	for rows.Next() {
		var ev Event
		if err = scanWithPlan(rows, plan, &ev, scratch); err != nil {
			return err
		}
		if err = handle(ev); err != nil {
			return err
		}
	}
	return rows.Err()
}

// parseEventQuery converts URL query parameters into an EventQuery.
//
// Recognised parameters:
//   - limit (default 100) and offset (default 0): non-negative integers.
//   - order_direction: "asc" or "desc", case-insensitive, default "desc".
//   - time_start / time_end: RFC 3339, "2006-01-02 15:04:05" or
//     "2006-01-02". Values without a zone are UTC (time.Parse semantics).
//     Empty means no bound.
//   - id, remote_addr, remote_port, dst_port, type, event, columns, asn,
//     country, city, domain, fqdn: repeated parameters and/or
//     comma-separated values.
//   - f:<field>: values for a JSON-field filter, kept verbatim.
//   - fe:<field>: JSON field names collected into FieldExists (sorted).
//
// It returns an error for a malformed or negative limit/offset, an unknown
// order direction, or a timestamp that matches none of the layouts above.
// A bad timestamp is not silently dropped, because that would widen the
// result set.
func (db *Database) parseEventQuery(query url.Values) (EventQuery, error) {
	limitInt, err := getIntWithDefault(query, "limit", 100)
	if err != nil {
		return EventQuery{}, err
	}
	if limitInt < 0 {
		return EventQuery{}, fmt.Errorf("invalid limit: %d (must not be negative)", limitInt)
	}

	offset, err := getIntWithDefault(query, "offset", 0)
	if err != nil {
		return EventQuery{}, err
	}
	if offset < 0 {
		return EventQuery{}, fmt.Errorf("invalid offset: %d (must not be negative)", offset)
	}

	rawDirection := query.Get("order_direction")
	orderDirection := strings.ToLower(rawDirection)
	if orderDirection == "" {
		orderDirection = "desc"
	}
	// Only these two values may reach the SQL text (ExportEvents concatenates it).
	if orderDirection != "asc" && orderDirection != "desc" {
		return EventQuery{}, fmt.Errorf("invalid order_direction: %s", rawDirection)
	}

	// parseTime reads an optional timestamp parameter; empty means "no bound".
	parseTime := func(key string) (time.Time, error) {
		raw := query.Get(key)
		if raw == "" {
			return time.Time{}, nil
		}
		for _, layout := range []string{time.RFC3339, time.DateTime, time.DateOnly} {
			if t, err := time.Parse(layout, raw); err == nil {
				return t, nil
			}
		}
		return time.Time{}, fmt.Errorf("invalid %s: %q (expected RFC 3339, %q or %q)", key, raw, time.DateTime, time.DateOnly)
	}

	timeStart, err := parseTime("time_start")
	if err != nil {
		return EventQuery{}, err
	}

	timeEnd, err := parseTime("time_end")
	if err != nil {
		return EventQuery{}, err
	}

	// Helper to handle both multiple parameters and comma-separated values
	// for fields where commas are not valid characters in the value itself.
	csvToSlice := func(key string) []string {
		var res []string
		for _, val := range query[key] {
			for _, p := range strings.Split(val, ",") {
				if p = strings.TrimSpace(p); p != "" {
					res = append(res, p)
				}
			}
		}
		return res
	}

	eventQuery := EventQuery{
		Limit:          limitInt,
		Offset:         offset,
		OrderDirection: orderDirection,
		TimeStart:      timeStart,
		TimeEnd:        timeEnd,
		FieldFilters:   make(map[string][]string),
		FieldExists:    []string{},
		IDs:            csvToSlice("id"),
		RemoteAddrs:    csvToSlice("remote_addr"),
		RemotePorts:    csvToSlice("remote_port"),
		DstPorts:       csvToSlice("dst_port"),
		Type:           csvToSlice("type"),
		Event:          csvToSlice("event"),
		Columns:        csvToSlice("columns"),
		ASNs:           csvToSlice("asn"),
		Countries:      csvToSlice("country"),
		Cities:         csvToSlice("city"),
		Domains:        csvToSlice("domain"),
		FQDNs:          csvToSlice("fqdn"),
	}

	for key, values := range query {
		if field, ok := strings.CutPrefix(key, "f:"); ok {
			eventQuery.FieldFilters[field] = append(eventQuery.FieldFilters[field], values...)
		} else if field, ok := strings.CutPrefix(key, "fe:"); ok {
			eventQuery.FieldExists = append(eventQuery.FieldExists, field)
		}
	}
	// Map iteration order is random; sort so FieldExists has a stable order
	// from one request to the next.
	slices.Sort(eventQuery.FieldExists)

	return eventQuery, nil
}

// getIntWithDefault parses the integer under key in query, using the first
// value if the key is repeated.
//
// Surrounding whitespace is ignored. An absent or blank value yields
// defaultVal. A value that is not a base-10 integer yields an error that names
// the parameter and wraps the strconv error, so errors.Is(err,
// strconv.ErrSyntax) / strconv.ErrRange still work.
func getIntWithDefault(query url.Values, key string, defaultVal int) (int, error) {
	raw := strings.TrimSpace(query.Get(key))
	if raw == "" {
		return defaultVal, nil
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return 0, fmt.Errorf("invalid %s: %w", key, err)
	}
	return n, nil
}