// internal/database/integration_test.go

package database

import (
	"net/url"
	"os"
	"path/filepath"
	"testing"
	"time"

	"honeypot/internal/types"
)

// TestDatabaseIntegration exercises the main read/write paths of the database
// against an in-memory instance. It inserts a single SSH auth-attempt event,
// then reads it back through the query, streaming and aggregation APIs.
//
// The steps share state (each check expects exactly the one event inserted in
// step 2), so they run sequentially in a single test rather than as
// independent subtests.
func TestDatabaseIntegration(t *testing.T) {
	// Use an in-memory database so the test never touches the filesystem.
	db := NewDatabase(":memory:")
	if db == nil {
		t.Fatal("Failed to create in-memory database")
	}
	defer db.Close()

	// 1. CreateTables is exercised implicitly (called by NewDatabase).

	// 2. InsertEvent: one SSH auth attempt from 1.2.3.4:12345 to port 22,
	// timestamped "now" so it also falls inside the 24h dashboard window.
	event := &types.LogEvent{
		Time:       time.Now().UTC().Format(time.RFC3339Nano),
		Type:       "ssh",
		Event:      types.EventAuthAttempt,
		RemoteAddr: "1.2.3.4",
		RemotePort: 12345,
		DstPort:    22,
		Fields:     map[string]any{"username": "admin"},
	}
	if err := db.InsertEvent(event); err != nil {
		t.Fatalf("InsertEvent failed: %v", err)
	}

	// 3. QueryEventsMeta: filtering on type=ssh should count our one event.
	q := url.Values{}
	q.Set("type", "ssh")
	meta, err := db.QueryEventsMeta(q)
	if err != nil {
		t.Fatalf("QueryEventsMeta failed: %v", err)
	}
	if meta.Total != 1 {
		t.Errorf("expected 1 event, got %v", meta.Total)
	}

	// 4. StreamEvents: the same filter should yield exactly that event.
	var events []Event
	err = db.StreamEvents(q, func(e Event) error {
		events = append(events, e)
		return nil
	})
	if err != nil {
		t.Fatalf("StreamEvents failed: %v", err)
	}
	if len(events) != 1 {
		// Fatal rather than Error: the check below indexes events[0] and
		// would panic on an empty slice instead of reporting a failure.
		t.Fatalf("expected 1 event in stream, got %v", len(events))
	}
	if events[0].RemoteAddr != "1.2.3.4" {
		t.Errorf("expected remote addr 1.2.3.4, got %v", events[0].RemoteAddr)
	}

	// 5. GetDashboardStats: 24h count and top remote addresses.
	stats, err := db.GetDashboardStats()
	if err != nil {
		t.Fatalf("GetDashboardStats failed: %v", err)
	}
	if stats.Count24h != 1 {
		t.Errorf("expected count_24h = 1, got %v", stats.Count24h)
	}
	// The length check short-circuits before RemoteAddrs[0] is read.
	if len(stats.StatsAll.RemoteAddrs) != 1 || stats.StatsAll.RemoteAddrs[0].Label != "1.2.3.4" {
		t.Errorf("unexpected top remote addrs: %v", stats.StatsAll.RemoteAddrs)
	}

	// 6. GetTopNFields: per-destination-port breakdown for ssh events.
	portStats, err := db.GetTopNFields("dst_port", "type = ?", []any{"ssh"}, 50)
	if err != nil {
		t.Fatalf("GetTopNFields failed: %v", err)
	}
	if len(portStats) != 1 || portStats[0].Label != "22" {
		t.Errorf("unexpected port stats: %v", portStats)
	}

	// 7. GetSubnetStats: the /24 containing 1.2.3.4 should list that address.
	subnetStats, err := db.GetSubnetStats("1.2.3.0/24")
	if err != nil {
		t.Fatalf("GetSubnetStats failed: %v", err)
	}
	if len(subnetStats) != 1 || subnetStats[0].Label != "1.2.3.4" {
		t.Errorf("unexpected subnet stats: %v", subnetStats)
	}

	// 8. GetFirstLastSeenTotalForNet: an unknown network must return an error
	// together with zero values.
	firstSeen, lastSeen, count, err := db.GetFirstLastSeenTotalForNet("non-existent")
	if err == nil {
		t.Fatal("GetFirstLastSeenTotalForNet should have failed for a non-existent network")
	}
	if firstSeen != "" || lastSeen != "" || count != 0 {
		t.Errorf("expected zero values for non-existent network, got %q, %q, %d", firstSeen, lastSeen, count)
	}
}

// TestNewDatabase checks how NewDatabase handles its path argument. An empty
// path must yield nil. A file path must yield a non-nil handle that leaves a
// database file on disk.
func TestNewDatabase(t *testing.T) {
	// An empty path is rejected by returning nil rather than a handle.
	if db := NewDatabase(""); db != nil {
		db.Close() // don't leak the unexpected handle
		t.Error("NewDatabase(\"\") should return nil")
	}

	// Use a per-test temporary directory instead of the working directory.
	// This keeps the source tree clean, avoids collisions between concurrent
	// test runs, and works on read-only checkouts. t.TempDir's cleanup also
	// removes any side files the database engine may create next to the
	// main file, which a single os.Remove would miss.
	tmpFile := filepath.Join(t.TempDir(), "test_db.duckdb")
	db := NewDatabase(tmpFile)
	if db == nil {
		t.Fatal("NewDatabase failed with temp file")
	}
	db.Close()

	// NewDatabase creates its tables on open, so the file should now exist.
	if _, err := os.Stat(tmpFile); err != nil {
		t.Errorf("expected database file %q to exist: %v", tmpFile, err)
	}
}