// internal/dashboard/handler.go

package dashboard

import (
	"context"
	"embed"
	"encoding/json"
	"fmt"
	"io/fs"
	"net"
	"net/http"
	"os"
	"path"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"

	"honeypot/internal/database"
	"honeypot/internal/geodb"
	"honeypot/internal/honeypot"
	"honeypot/internal/types"
	"honeypot/internal/utils"
)

// embeddedDist holds the contents of the frontend/dist directory, compiled
// into the binary by the go:embed directive below. FrontendHandler serves
// the SPA from this file system.
//
//go:embed frontend/dist
var embeddedDist embed.FS

// spaRoutes lists the URL paths that are handled by the client-side router.
// FrontendHandler answers every path matching one of these expressions with
// the SPA index document instead of looking for a file of that name. The
// explicit list matters for paths whose last segment contains a dot
// (for example /ip/1.2.3.4), which would otherwise be treated as asset files.
var spaRoutes = func() []*regexp.Regexp {
	patterns := [...]string{
		"^/$",
		"^/login$",
		"^/events$",
		"^/charts$",
		"^/charts/map$",
		"^/charts/port$",
		"^/stats$",
		"^/ip/.+$",
		"^/port/.+$",
		"^/city/.+$",
		"^/country/.+$",
		"^/asn/.+$",
		"^/domain/.+$",
		"^/fqdn/.+$",
	}

	compiled := make([]*regexp.Regexp, 0, len(patterns))
	for _, expr := range patterns {
		compiled = append(compiled, regexp.MustCompile(expr))
	}
	return compiled
}()

// FrontendHandler returns the handler that serves the embedded SPA frontend.
//
// Requests whose path matches one of spaRoutes, or whose last path segment
// contains no dot, are answered with the index document (the path is
// rewritten to "/"). All remaining requests are passed to the file server
// unchanged so that assets are served by name. If the embedded
// frontend/dist subtree cannot be opened, the returned handler answers every
// request with 503.
func (s *Service) FrontendHandler() http.Handler {
	dist, err := fs.Sub(embeddedDist, "frontend/dist")
	if err != nil {
		unavailable := func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusServiceUnavailable)
			_, _ = w.Write([]byte("dashboard assets not found"))
		}
		return http.HandlerFunc(unavailable)
	}

	assets := http.FileServer(http.FS(dist))

	// isSPARoute reports whether the client-side router owns urlPath.
	isSPARoute := func(urlPath string) bool {
		for i := range spaRoutes {
			if spaRoutes[i].MatchString(urlPath) {
				return true
			}
		}
		return false
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A dot in the last segment is taken to mean "this is an asset file".
		looksLikeFile := strings.Contains(path.Base(r.URL.Path), ".")
		if isSPARoute(r.URL.Path) || !looksLikeFile {
			r.URL.Path = "/"
		}
		assets.ServeHTTP(w, r)
	})
}

// ListEvents handles GET /api/events.
//
// The response is a single JSON object streamed to the client: the query
// metadata fields ("query", "where_args", "total", "query_time") followed by
// an "events" array whose elements are written and flushed one at a time.
//
// Nothing is written to the response until the first event arrives (or the
// stream finishes without events). This way a failure that happens before any
// data has been sent is still reported as a proper 500. Once the body has
// started the status line is already on the wire, so a later failure simply
// stops the stream; the client can detect it from the truncated JSON.
func (s *Service) ListEvents(w http.ResponseWriter, r *http.Request) {
	meta, err := s.database.QueryEventsMeta(r.URL.Query())
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	enc := json.NewEncoder(w)
	flusher, _ := w.(http.Flusher)

	// writeJSONField emits `"key":<value>,`. Keys are trusted constants and
	// therefore not escaped.
	writeJSONField := func(key string, value any) error {
		if _, err := w.Write([]byte(`"` + key + `":`)); err != nil {
			return err
		}
		if err := enc.Encode(value); err != nil {
			return err
		}
		_, err := w.Write([]byte(","))
		return err
	}

	// start commits the headers and writes everything up to and including
	// the opening bracket of the events array.
	started := false
	start := func() error {
		started = true

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Transfer-Encoding", "chunked")

		if _, err := w.Write([]byte("{")); err != nil {
			return err
		}

		fields := []struct {
			key   string
			value any
		}{
			{"query", meta.Query},
			{"where_args", meta.WhereArgs},
			{"total", meta.Total},
			{"query_time", meta.QueryTime.String()},
		}
		for _, f := range fields {
			if err := writeJSONField(f.key, f.value); err != nil {
				return err
			}
		}

		_, err := w.Write([]byte(`"events":[`))
		return err
	}

	err = s.database.StreamEvents(r.URL.Query(), func(event database.Event) error {
		if !started {
			if err := start(); err != nil {
				return err
			}
		} else if _, err := w.Write([]byte(",")); err != nil {
			// The client is most likely gone; stop iterating.
			return err
		}

		if err := enc.Encode(event); err != nil {
			return err
		}

		if flusher != nil {
			flusher.Flush()
		}
		return nil
	})

	if err != nil {
		// http.Error is only meaningful while the response is untouched;
		// afterwards it would append plain text to the partial JSON body.
		if !started {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}

	// No events at all: the preamble still has to be sent.
	if !started {
		if err := start(); err != nil {
			return
		}
	}

	// ---- Close array + object ----
	// A failure here means the client disconnected; there is nothing left to do.
	_, _ = w.Write([]byte("]}"))
}

// GetStats handles GET /api/stats.
//
// It responds with the value returned by the database's GetDashboardStats,
// encoded as JSON. A database or encoding failure is reported with status 500.
func (s *Service) GetStats(w http.ResponseWriter, r *http.Request) {
	payload, queryErr := s.database.GetDashboardStats()
	if queryErr != nil {
		http.Error(w, queryErr.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetSystemStats handles GET /api/system-stats.
//
// The JSON response combines the database statistics with the last IP-info
// run, the time the score cache was last updated, the modification times of
// the two GeoLite database files and a flag telling whether both GeoLite
// download URLs are configured.
func (s *Service) GetSystemStats(w http.ResponseWriter, r *http.Request) {
	dbStats, err := s.database.GetDatabaseStats()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Both URLs must be present for a GeoLite update to be possible.
	geoURLsConfigured := s.asnDBURL != "" && s.cityDBURL != ""

	payload := make(map[string]any, 6)
	payload["database"] = dbStats
	payload["ip_info_last_run"] = s.database.IPInfoLastRun
	payload["score_cache_updated"] = s.scoreCache.LastUpdated
	payload["geolite_asn_date"] = getFileModTime(s.asnDBFile)
	payload["geolite_city_date"] = getFileModTime(s.cityDBFile)
	payload["geolite_urls_set"] = geoURLsConfigured

	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// getFileModTime returns the modification time of the file at path, or the
// zero time.Time when the file cannot be stat'ed (for example because it does
// not exist).
func getFileModTime(path string) time.Time {
	if fi, statErr := os.Stat(path); statErr == nil {
		return fi.ModTime()
	}
	return time.Time{}
}

// TriggerGeoDBUpdate handles POST /api/system/update-geodb.
//
// It downloads the ASN and city database files concurrently from the
// configured URLs and then asks the GeoDB to reload them. Non-POST requests
// get 405, missing download URLs get 400, and download or reload failures
// get 500. On success an empty 200 response is sent.
func (s *Service) TriggerGeoDBUpdate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	urlsConfigured := s.asnDBURL != "" && s.cityDBURL != ""
	if !urlsConfigured {
		http.Error(w, "GeoLite download URLs not configured", http.StatusBadRequest)
		return
	}

	// Fetch both files in parallel; Wait reports the first failure.
	var downloads errgroup.Group
	downloads.Go(func() error { return geodb.DownloadFile(s.asnDBURL, s.asnDBFile) })
	downloads.Go(func() error { return geodb.DownloadFile(s.cityDBURL, s.cityDBFile) })

	if downloadErr := downloads.Wait(); downloadErr != nil {
		msg := fmt.Sprintf("failed to download updates: %v", downloadErr)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	// Swap the freshly downloaded files in.
	reloadErr := s.geodb.Reload(s.asnDBFile, s.cityDBFile)
	if reloadErr != nil {
		msg := fmt.Sprintf("failed to reload GeoDB: %v", reloadErr)
		http.Error(w, msg, http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
}

// ExportEvents handles GET /api/events/export/{format}.
//
// Only the "json" format is supported; anything else yields 400. The events
// matching the request's query parameters are streamed as a JSON array that
// is offered to the browser as the download "events.json".
//
// The download headers and the opening bracket are sent lazily, together
// with the first event (or at the end when there are no events). A failure
// that happens before any data has been sent is therefore still reported as
// a regular 500 instead of being appended to a started 200 download. After
// the body has started, a failure simply ends the stream; the truncated
// array tells the client the export is incomplete.
func (s *Service) ExportEvents(w http.ResponseWriter, r *http.Request) {
	format := r.PathValue("format")
	if format != "json" {
		http.Error(w, "invalid format", http.StatusBadRequest)
		return
	}

	enc := json.NewEncoder(w)
	flusher, _ := w.(http.Flusher)

	// start commits the headers and opens the JSON array.
	started := false
	start := func() error {
		started = true

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("Content-Disposition", "attachment; filename=events.json")
		// Important: disable buffering in some proxies
		w.Header().Set("Transfer-Encoding", "chunked")

		_, err := w.Write([]byte("["))
		return err
	}

	err := s.database.ExportEvents(r.URL.Query(), func(event database.Event) error {
		if !started {
			if err := start(); err != nil {
				return err
			}
		} else if _, err := w.Write([]byte(",")); err != nil {
			// The client is most likely gone; stop iterating.
			return err
		}

		if err := enc.Encode(event); err != nil {
			return err
		}

		if flusher != nil {
			flusher.Flush()
		}

		return nil
	})

	if err != nil {
		// http.Error is only meaningful while the response is untouched;
		// afterwards it would append plain text to the partial JSON body.
		if !started {
			http.Error(w, fmt.Sprintf("failed to export events: %v", err), http.StatusInternalServerError)
		}
		return
	}

	// No events at all: still produce a valid, empty array.
	if !started {
		if err := start(); err != nil {
			return
		}
	}

	// End JSON array. A failure here means the client disconnected; there is
	// nothing left to do.
	_, _ = w.Write([]byte("]"))
}

// DNSLookupResponse is a JSON response envelope with the fields "ips",
// "not_found" and "query_time".
//
// NOTE(review): no handler in the visible part of this file uses this type;
// confirm it is still referenced elsewhere.
type DNSLookupResponse struct {
	// IPs is serialized as "ips". Presumably maps each queried name to its
	// resolved address — TODO confirm against the producer.
	IPs       map[string]string `json:"ips"`
	// NotFound is serialized as "not_found". Presumably the queries that
	// produced no result — TODO confirm.
	NotFound  []string          `json:"not_found"`
	// QueryTime is serialized as "query_time". Presumably a formatted
	// duration string — TODO confirm.
	QueryTime string            `json:"query_time"`
}

// getOrFetchIPMetadataBulk retrieves IP metadata for multiple IPs, fetching
// and caching any that are missing or stale.
//
// The input is trimmed and de-duplicated; empty entries are dropped. Cached
// rows are loaded from the database first. Every IP that has no row, or whose
// row was last updated more than 24 hours ago, is looked up through the GeoDB,
// upserted into the database and re-read so the returned entry carries the
// stored LastUpdated value.
//
// Lookups run concurrently but are capped at a fixed number in flight,
// because the IP list can come straight from a request. Individual lookup or
// upsert failures are skipped silently: the IP keeps its stale entry or stays
// absent from the result. If ctx is cancelled, no further lookups are started
// and whatever has been collected so far is returned.
//
// An error is returned only when the database is not initialized or the
// initial cache read fails.
func (s *Service) getOrFetchIPMetadataBulk(ctx context.Context, ips []string) (map[string]*database.IPMetadata, error) {
	// maxConcurrentLookups bounds the number of simultaneous GeoDB lookups.
	const maxConcurrentLookups = 16

	if s.database == nil {
		return nil, fmt.Errorf("database not initialized")
	}

	uniqueIPs := make([]string, 0, len(ips))
	seenIP := make(map[string]struct{}, len(ips))
	for _, ip := range ips {
		ip = strings.TrimSpace(ip)
		if ip == "" {
			continue
		}
		if _, dup := seenIP[ip]; dup {
			continue
		}
		seenIP[ip] = struct{}{}
		uniqueIPs = append(uniqueIPs, ip)
	}

	// Nothing to look up: avoid a pointless database round trip.
	if len(uniqueIPs) == 0 {
		return map[string]*database.IPMetadata{}, nil
	}

	results, err := s.database.GetIPMetadata(uniqueIPs)
	if err != nil {
		return nil, err
	}
	if results == nil {
		// Writing fresh entries into a nil map would panic.
		results = make(map[string]*database.IPMetadata, len(uniqueIPs))
	}

	missing := make([]string, 0, len(uniqueIPs))
	for _, ip := range uniqueIPs {
		meta, ok := results[ip]
		if !ok || meta == nil || time.Since(meta.LastUpdated) > 24*time.Hour {
			missing = append(missing, ip)
		}
	}

	// Without a GeoDB there is no way to refresh anything.
	if len(missing) == 0 || s.geodb == nil {
		return results, nil
	}

	// Fetch missing/stale metadata concurrently, at most
	// maxConcurrentLookups at a time.
	var wg sync.WaitGroup
	var mu sync.Mutex
	sem := make(chan struct{}, maxConcurrentLookups)

launch:
	for _, ip := range missing {
		select {
		case sem <- struct{}{}:
		case <-ctx.Done():
			break launch
		}

		wg.Add(1)
		go func(ip string) {
			defer wg.Done()
			defer func() { <-sem }()

			fresh, err := s.geodb.LookupMetadata(ctx, ip)
			if err != nil {
				return
			}

			dbMeta := &database.IPMetadata{
				IP:        fresh.IP,
				Country:   fresh.CountryCode,
				ASN:       fresh.ASN,
				ASNOrg:    fresh.ASNOrg,
				City:      fresh.City,
				Latitude:  fresh.Latitude,
				Longitude: fresh.Longitude,
				FQDN:      fresh.FQDN,
				Domain:    fresh.Domain,
			}

			// Best effort: IPInt keeps its zero value when the conversion fails.
			if ipInt, err := utils.IPToInt(ip); err == nil {
				dbMeta.IPInt = ipInt
			}

			if err := s.database.UpsertIPMetadata(dbMeta); err != nil {
				return
			}

			// Reload from DB to get the correct LastUpdated. A failed reload
			// leaves the previous entry (if any) in place.
			m, _ := s.database.GetIPMetadata([]string{ip})
			if meta, ok := m[ip]; ok {
				mu.Lock()
				results[ip] = meta
				mu.Unlock()
			}
		}(ip)
	}
	wg.Wait()

	return results, nil
}

// GetIPInfo handles GET /api/ipinfo.
//
// The "ip" query parameter holds one or more comma-separated addresses. For
// every non-empty address the response's "results" object carries the ASN,
// city, country, location, FQDN and domain taken from the stored metadata;
// addresses without metadata are listed under "not_found". "query_time" is
// the time spent handling the request.
func (s *Service) GetIPInfo(w http.ResponseWriter, r *http.Request) {
	began := time.Now()

	raw := r.URL.Query().Get("ip")
	if raw == "" {
		http.Error(w, "ip required", http.StatusBadRequest)
		return
	}

	requested := strings.Split(raw, ",")
	metas, err := s.getOrFetchIPMetadataBulk(r.Context(), requested)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	var notFound []string
	results := make(map[string]any)

	for _, candidate := range requested {
		addr := strings.TrimSpace(candidate)
		if addr == "" {
			continue
		}

		meta, known := metas[addr]
		if !known {
			notFound = append(notFound, addr)
			continue
		}

		asn := map[string]any{
			"autonomous_system_number":       meta.ASN,
			"autonomous_system_organization": meta.ASNOrg,
		}
		city := map[string]any{
			"name":    meta.City,
			"country": meta.Country,
		}
		country := map[string]any{
			"iso_code": meta.Country,
		}
		location := map[string]any{
			"latitude":  meta.Latitude,
			"longitude": meta.Longitude,
		}

		results[addr] = map[string]any{
			"asn":      asn,
			"city":     city,
			"country":  country,
			"location": location,
			"fqdn":     meta.FQDN,
			"domain":   meta.Domain,
		}
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")

	payload := map[string]any{
		"results":    results,
		"not_found":  notFound,
		"query_time": time.Since(began).String(),
	}
	// The encode error is deliberately ignored, as there is no way to report
	// it once the body has been started.
	_ = json.NewEncoder(w).Encode(payload)
}

// GetPortStats handles GET /api/stats/port.
//
// The required "port" query parameter must be a plain decimal number in the
// range 0-65535; anything else (including trailing garbage such as "80abc")
// is rejected with 400. The JSON response contains the top remote addresses
// and top events for that destination port, together with the total event
// count and the first/last time the port was seen. Database failures are
// reported with 500.
func (s *Service) GetPortStats(w http.ResponseWriter, r *http.Request) {
	portStr := r.URL.Query().Get("port")
	if portStr == "" {
		http.Error(w, "port required", http.StatusBadRequest)
		return
	}

	// bitSize 16 makes ParseUint enforce the valid port range, and unlike
	// fmt.Sscanf it refuses input with anything after the digits.
	parsed, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		http.Error(w, "invalid port", http.StatusBadRequest)
		return
	}
	port := int(parsed)

	topAddrs, err := s.database.GetTopNFields("remote_addr", "type = 'packetlogger' AND dst_port = ?", []any{port}, 50)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	topEvents, err := s.database.GetTopNFields("event", "dst_port = ?", []any{port}, 50)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	firstSeen, lastSeen, count, err := s.database.GetPortStatsOverview(port)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	res := map[string]any{
		"top_addrs":    topAddrs,
		"top_events":   topEvents,
		"total_events": count,
		"first_seen":   firstSeen,
		"last_seen":    lastSeen,
	}

	if err := json.NewEncoder(w).Encode(res); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

// GetIpStats handles GET /api/stats/ip.
//
// The required "ip" query parameter selects the address (or network) to
// report on. The JSON response contains the total event count, the first and
// last time the address was seen, its top destination ports and events, and
// matching blocklist entries.
//
// The WHERE clause for the address is built before any query is started, so
// a request that fails this step returns immediately without leaving a
// database query running in the background. The four queries then run
// concurrently; the first error, if any, is reported with 500.
func (s *Service) GetIpStats(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	if ip == "" {
		http.Error(w, "ip required", http.StatusBadRequest)
		return
	}

	where, args, err := database.GetIpWhere(ip)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	var (
		firstSeen, lastSeen string
		count               int
		topPorts            []database.LabelCount
		topEvents           []database.LabelCount
		blocklist           []types.BlocklistEntry
	)

	// Each goroutine writes only its own variables; they are read after Wait.
	var g errgroup.Group

	g.Go(func() error {
		var err error
		firstSeen, lastSeen, count, err = s.database.GetFirstLastSeenTotalForNet(ip)
		return err
	})

	g.Go(func() error {
		var err error
		topPorts, err = s.database.GetTopNFields("dst_port", "type = 'packetlogger' AND "+where, args, 50)
		return err
	})

	g.Go(func() error {
		var err error
		topEvents, err = s.database.GetTopNFields("event", where, args, 50)
		return err
	})

	g.Go(func() error {
		var err error
		blocklist, err = s.database.GetBlocklistForNet(ip, 50)
		return err
	})

	if err := g.Wait(); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]any{
		"total_events": count,
		"first_seen":   firstSeen,
		"last_seen":    lastSeen,
		"top_ports":    topPorts,
		"top_events":   topEvents,
		"blocklist":    blocklist,
	}); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

// GetGeoStats handles GET /api/stats/geo.
//
// Both the "type" and "value" query parameters are required (400 otherwise).
// They are passed to the database's GetGeoStats, and the result is reshaped
// into the JSON object the frontend expects. Database or encoding failures
// are reported with 500.
func (s *Service) GetGeoStats(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	geoType, value := params.Get("type"), params.Get("value")

	if geoType == "" || value == "" {
		http.Error(w, "type and value required", http.StatusBadRequest)
		return
	}

	stats, err := s.database.GetGeoStats(geoType, value)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	payload := map[string]any{
		"title":          stats.Title,
		"metadata":       stats.Metadata,
		"top_addrs":      stats.RemoteAddrs,
		"top_events":     stats.EventTypes,
		"top_ports":      stats.Ports,
		"top_subdomains": stats.FQDNs,
		"total_events":   stats.TotalEvents,
		"first_seen":     stats.FirstSeen,
		"last_seen":      stats.LastSeen,
	}
	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetHoneypotStats handles GET /api/stats/honeypot.
//
// The required "event_type" query parameter is passed to the database's
// GetHoneypotStats and the result is returned as JSON. A missing parameter
// yields 400; database or encoding failures yield 500.
func (s *Service) GetHoneypotStats(w http.ResponseWriter, r *http.Request) {
	eventType := r.URL.Query().Get("event_type")
	if len(eventType) == 0 {
		http.Error(w, "event_type required", http.StatusBadRequest)
		return
	}

	payload, queryErr := s.database.GetHoneypotStats(eventType)
	if queryErr != nil {
		http.Error(w, queryErr.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetSubnetStats handles GET /api/stats/subnet.
//
// The "ip" and "mask" query parameters are both required. "ip" must parse as
// an IP address. "mask" is read as a decimal prefix length; when it cannot be
// parsed, a prefix length of 24 is used instead. The resulting network is
// normalized through net.ParseCIDR and passed to the database's
// GetSubnetStats, whose result is returned as JSON. Invalid input yields 400;
// database or encoding failures yield 500.
func (s *Service) GetSubnetStats(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	ipStr, maskStr := params.Get("ip"), params.Get("mask")

	if ipStr == "" || maskStr == "" {
		http.Error(w, "ip and mask required", http.StatusBadRequest)
		return
	}

	if net.ParseIP(ipStr) == nil {
		http.Error(w, "invalid ip", http.StatusBadRequest)
		return
	}

	// The scan result is intentionally ignored: an unparsable mask leaves
	// the default prefix length in place.
	prefixLen := 24
	_, _ = fmt.Sscanf(maskStr, "%d", &prefixLen)

	cidr := fmt.Sprintf("%s/%d", ipStr, prefixLen)
	_, subnet, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		// Reached when the prefix length does not fit the address family.
		http.Error(w, "invalid subnet configuration", http.StatusBadRequest)
		return
	}

	payload, queryErr := s.database.GetSubnetStats(subnet.String())
	if queryErr != nil {
		http.Error(w, queryErr.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(payload)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetActivityOverTime handles GET /api/stats/activity-over-time.
//
// The request's query parameters are handed to the database's
// GetActivityOverTime unchanged and the result is returned as JSON. Database
// or encoding failures are reported with 500.
func (s *Service) GetActivityOverTime(w http.ResponseWriter, r *http.Request) {
	series, queryErr := s.database.GetActivityOverTime(r.URL.Query())
	if queryErr != nil {
		http.Error(w, queryErr.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(series)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetActiveHoneypots handles GET /api/active-honeypots.
//
// It returns a JSON array with one object per configured honeypot, carrying
// its name, label and ports. With no honeypots the array is empty, not null.
func (s *Service) GetActiveHoneypots(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// A non-nil slice guarantees "[]" rather than "null" in the output.
	active := make([]map[string]any, 0, len(s.honeypots))
	for _, hp := range s.honeypots {
		entry := make(map[string]any, 3)
		entry["name"] = string(hp.Name())
		entry["label"] = hp.Label()
		entry["ports"] = hp.Ports()
		active = append(active, entry)
	}

	encodeErr := json.NewEncoder(w).Encode(active)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetBlocklistEntries handles GET /api/blocklist-entries.
//
// It refreshes the score cache if it is due, then returns the blocked
// addresses stored in the database as JSON. Database or encoding failures
// are reported with 500.
func (s *Service) GetBlocklistEntries(w http.ResponseWriter, r *http.Request) {
	updateScoreCache(s)

	entries, queryErr := s.database.GetBlockedAddresses()
	if queryErr != nil {
		http.Error(w, queryErr.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	encodeErr := json.NewEncoder(w).Encode(entries)
	if encodeErr != nil {
		http.Error(w, encodeErr.Error(), http.StatusInternalServerError)
	}
}

// GetBlockList handles GET /api/blocklist.
//
// The response is plain text with one blocked IP address or subnet per line,
// a format meant for ingestion into OPNsense. The score cache is refreshed
// first if it is due. A database failure is reported with 500.
func (s *Service) GetBlockList(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/plain")
	updateScoreCache(s)

	entries, queryErr := s.database.GetBlockedAddresses()
	if queryErr != nil {
		http.Error(w, queryErr.Error(), http.StatusInternalServerError)
		return
	}

	for i := range entries {
		line := entries[i].Address + "\n"
		w.Write([]byte(line))
	}
}

// updateScoreCache recomputes the score cache and the stored blocklist when
// the cache is at least two minutes old; otherwise it does nothing.
//
// Scores from every honeypot and from authentication attempts are merged
// over a '3 HOURS' window. Entries scoring below 300 are dropped, the
// remainder is passed to honeypot.UpdateBlocklist and stored in the cache
// together with the current time.
//
// NOTE(review): the cache fields are read and written here without any
// visible locking — confirm callers cannot run this concurrently.
func updateScoreCache(s *Service) {
	const (
		refreshInterval = 2 * time.Minute
		scoreWindow     = "'3 HOURS'"
		minScore        = 300
	)

	if age := time.Since(s.scoreCache.LastUpdated); age < refreshInterval {
		return
	}

	merged := honeypot.ScoreMap{}
	for _, hp := range s.honeypots {
		merged = honeypot.MergeScores(merged, hp.GetScores(s.database, scoreWindow))
	}
	merged = honeypot.MergeScores(merged, honeypot.GetAuthAttemptScores(s.database, scoreWindow))

	// Keep only entries that reached the minimum score.
	for addr, entry := range merged {
		if entry.Score < minScore {
			delete(merged, addr)
		}
	}

	honeypot.UpdateBlocklist(s.database, merged)

	s.scoreCache.Scores = merged
	s.scoreCache.LastUpdated = time.Now()
}