// internal/dashboard/server.go

package dashboard

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"log/slog"
	"net/http"
	"path"
	"strings"
	"sync"
	"time"

	"honeypot/internal/database"
	"honeypot/internal/geodb"
	"honeypot/internal/honeypot"
	"honeypot/internal/logger"
	"honeypot/internal/utils"
)

// Service manages the dashboard state including websocket hub and log file access.
//
// Instances are built by NewService, which also starts a background goroutine
// that prunes expired sessions until ctx is cancelled.
type Service struct {
	// hub is the websocket hub; StartServer runs it and shuts it down on ctx cancellation.
	hub        *Hub
	// logFile is the log file path handed to NewService (cfg.LogFile in StartServer).
	logFile    string
	database   *database.Database
	geodb      *geodb.GeoDB
	// honeypots is the set of honeypots supplied by the caller of NewService.
	honeypots  []honeypot.Honeypot
	// scoreCache starts empty (no scores, zero LastUpdated). NOTE(review): its
	// readers/writers are outside this file — confirm whether it needs locking.
	scoreCache honeypot.ScoreCache
	// uiPassword is the password checked by Login; when empty, withUIAuth lets
	// every request through.
	uiPassword string
	// apiToken is the bearer token required by /api/blocklist; empty disables the check.
	apiToken   string
	// sessions maps session ID to its expiry time; guarded by sessionMu.
	sessions   map[string]time.Time
	sessionMu  sync.RWMutex
	// ctx is passed to NewHub and stops the session-cleanup goroutine when done.
	ctx        context.Context
	// GeoLite2 ASN/City database file paths and download URLs copied from
	// ServerConfig; presumably consumed by the GeoDB update handler — TODO confirm.
	asnDBFile  string
	cityDBFile string
	asnDBURL   string
	cityDBURL  string
}

// NewService creates a new dashboard service bound to ctx.
//
// It builds a websocket hub from ctx, starts with an empty score cache and no
// sessions, copies the GeoLite2 file paths/URLs out of cfg, and launches a
// goroutine that periodically drops expired sessions until ctx is cancelled.
func NewService(ctx context.Context, logFile string, database *database.Database, geodb *geodb.GeoDB, honeypots []honeypot.Honeypot, uiPassword string, apiToken string, cfg ServerConfig) *Service {
	svc := &Service{
		ctx:       ctx,
		hub:       NewHub(ctx),
		logFile:   logFile,
		database:  database,
		geodb:     geodb,
		honeypots: honeypots,

		scoreCache: honeypot.ScoreCache{
			Scores:      honeypot.ScoreMap{},
			LastUpdated: time.Time{},
		},

		uiPassword: uiPassword,
		apiToken:   apiToken,
		sessions:   make(map[string]time.Time),

		asnDBFile:  cfg.ASNDBFile,
		cityDBFile: cfg.CityDBFile,
		asnDBURL:   cfg.ASNDBURL,
		cityDBURL:  cfg.CityDBURL,
	}

	// The sweeper's lifetime is tied to ctx.
	go svc.cleanupSessions()

	return svc
}

// cleanupSessions removes expired entries from s.sessions once an hour and
// returns when s.ctx is done.
func (s *Service) cleanupSessions() {
	sweep := time.NewTicker(time.Hour)
	defer sweep.Stop()

	for {
		select {
		case <-sweep.C:
		case <-s.ctx.Done():
			return
		}

		s.sessionMu.Lock()
		cutoff := time.Now()
		for token, expiresAt := range s.sessions {
			if cutoff.After(expiresAt) {
				delete(s.sessions, token)
			}
		}
		s.sessionMu.Unlock()
	}
}

// createSession generates a random session ID (32 bytes from crypto/rand,
// hex-encoded to 64 characters), records it with a 24-hour expiry, and returns
// it. An error is returned only if the random source fails.
func (s *Service) createSession() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	token := hex.EncodeToString(raw[:])

	s.sessionMu.Lock()
	defer s.sessionMu.Unlock()
	s.sessions[token] = time.Now().Add(24 * time.Hour)

	return token, nil
}

// isSessionValid reports whether id names a known, unexpired session.
// An expired session found during the check is removed immediately.
func (s *Service) isSessionValid(id string) bool {
	s.sessionMu.RLock()
	expiresAt, found := s.sessions[id]
	s.sessionMu.RUnlock()

	switch {
	case !found:
		return false
	case !time.Now().After(expiresAt):
		return true
	}

	// Expired: evict eagerly instead of waiting for the hourly sweep.
	s.sessionMu.Lock()
	delete(s.sessions, id)
	s.sessionMu.Unlock()
	return false
}

// Hub returns the websocket hub created for this service by NewService.
func (s *Service) Hub() *Hub {
	h := s.hub
	return h
}

// EventSink returns a new EventSink that feeds this service's websocket hub.
// Each call constructs a fresh sink via NewEventSink.
func (s *Service) EventSink() *EventSink {
	hub := s.Hub()
	return NewEventSink(hub)
}

// RegisterRoutes registers all dashboard routes on the provided mux.
//
// Login, logout and auth-status are registered without a wrapper so they are
// reachable before a session exists. /api/blocklist is protected by the bearer
// API token. Every other route is wrapped in the UI session check (withUIAuth).
func (s *Service) RegisterRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
		public  bool
	}{
		{"GET /", s.FrontendHandler().ServeHTTP, false},
		{"GET /ws", s.ServeWebSocket, false},

		{"POST /api/login", s.Login, true},
		{"POST /api/logout", s.Logout, true},
		{"GET /api/auth-status", s.AuthStatus, true},

		{"GET /api/events", s.ListEvents, false},
		{"GET /api/stats", s.GetStats, false},
		{"GET /api/stats/ip", s.GetIpStats, false},
		{"GET /api/stats/subnet", s.GetSubnetStats, false},
		{"GET /api/stats/port", s.GetPortStats, false},
		{"GET /api/stats/honeypot", s.GetHoneypotStats, false},
		{"GET /api/stats/activity-over-time", s.GetActivityOverTime, false},
		{"GET /api/system-stats", s.GetSystemStats, false},
		{"GET /api/events/export/{format}", s.ExportEvents, false},
		{"GET /api/ipinfo", s.GetIPInfo, false},
		{"GET /api/stats/geo", s.GetGeoStats, false},
		{"GET /api/active-honeypots", s.GetActiveHoneypots, false},
		{"GET /api/blocklist-entries", s.GetBlocklistEntries, false},
		{"POST /api/system/update-geodb", s.TriggerGeoDBUpdate, false},
	}

	for _, rt := range routes {
		if rt.public {
			mux.HandleFunc(rt.pattern, rt.handler)
			continue
		}
		mux.Handle(rt.pattern, s.withUIAuth(rt.handler))
	}

	mux.Handle("GET /api/blocklist", withBearerAuth(s.apiToken, http.HandlerFunc(s.GetBlockList)))
}

// ServerConfig holds the settings StartServer uses to build the dashboard HTTP
// server. The field notes below describe how this file reads each field.
type ServerConfig struct {
	// ListenAddr is combined with UIPort via utils.BuildAddress to form the listen address.
	ListenAddr       string
	// UIPort is the TCP port to listen on; 0 makes StartServer return without starting anything.
	UIPort           uint16
	// UIPassword is the dashboard login password; empty disables UI session auth.
	UIPassword       string
	// APIToken is the bearer token required by /metrics and /api/blocklist;
	// empty disables the bearer check.
	APIToken         string
	// LogFile is passed through to NewService.
	LogFile          string
	// DisableMetrics suppresses registration of the /metrics endpoint.
	DisableMetrics   bool
	// DisableDashboard suppresses creation of the dashboard Service and its routes.
	DisableDashboard bool
	// Honeypots — NOTE(review): StartServer uses its own honeypots parameter,
	// not this field; confirm whether any caller still relies on it.
	Honeypots        []honeypot.Honeypot
	// GeoLite2 ASN/City database file paths and download URLs, copied into the
	// Service by NewService.
	ASNDBFile        string
	CityDBFile       string
	ASNDBURL         string
	CityDBURL        string
}

// StartServer starts, without blocking, the HTTP server that hosts the
// dashboard and/or the /metrics endpoint on the address built from
// cfg.ListenAddr and cfg.UIPort.
//
// The dashboard is enabled when cfg.DisableDashboard is false and both
// database and geodb are non-nil. Metrics are enabled when cfg.DisableMetrics
// is false and metricsCollector is non-nil. The listener is started when
// either one is enabled, so a metrics-only deployment is still served.
//
// If cfg.UIPort is 0, nothing is started and shutdownDone (when non-nil) is
// closed immediately. Otherwise a goroutine waits for ctx to be cancelled. It
// then shuts down the websocket hub (if any) and the HTTP server, the latter
// with a 5s timeout. Finally it closes shutdownDone (when non-nil).
func StartServer(
	ctx context.Context,
	cfg ServerConfig,
	l *slog.Logger,
	metricsCollector logger.MetricsCollector,
	honeypots []honeypot.Honeypot,
	database *database.Database,
	geodb *geodb.GeoDB,
	shutdownDone chan struct{},
) {
	if cfg.UIPort == 0 {
		if shutdownDone != nil {
			close(shutdownDone)
		}
		return
	}

	addr := utils.BuildAddress(cfg.ListenAddr, cfg.UIPort)
	mux := http.NewServeMux()

	// Report the real reason the dashboard is off: an explicit opt-out is not
	// the same as missing database/GeoLite2 dependencies.
	var svc *Service
	switch {
	case cfg.DisableDashboard:
		logger.LogInfo(l, "dashboard", "server_disabled", []any{"reason", "dashboard disabled by configuration"})
	case database == nil || geodb == nil:
		logger.LogInfo(l, "dashboard", "server_disabled", []any{"reason", "database not configured or GeoLite2 databases not found"})
	default:
		svc = NewService(ctx, cfg.LogFile, database, geodb, honeypots, cfg.UIPassword, cfg.APIToken, cfg)
	}

	metricsEnabled := !cfg.DisableMetrics && metricsCollector != nil
	if metricsEnabled {
		logger.LogInfo(l, "dashboard", "metrics_collector_enabled", nil)
		mux.Handle("GET /metrics", withBearerAuth(cfg.APIToken, metricsCollector.GetHandler()))
	}

	if svc != nil {
		go svc.Hub().Run()
		logger.RegisterEventSink(svc.EventSink())
		svc.RegisterRoutes(mux)
	}

	server := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
	}

	// Serve whenever any route was registered. Gating this on the dashboard
	// alone would leave /metrics registered but unreachable.
	if svc != nil || metricsEnabled {
		go func() {
			logger.LogInfo(l, "dashboard", "server_starting", []any{"addr", addr})
			if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				logger.LogError(l, "dashboard", "server_error", err, []any{"addr", addr})
			}
		}()
		logger.LogInfo(l, "dashboard", "server_listening", []any{"addr", addr})
	}

	// Handle shutdown in a goroutine. Shutdown on a server that was never
	// started returns promptly, so this path is safe when nothing was served.
	go func() {
		defer func() {
			if shutdownDone != nil {
				close(shutdownDone)
			}
		}()

		<-ctx.Done()
		logger.LogInfo(l, "dashboard", "shutdown_initiated", nil)

		// Shut down the websocket hub before the HTTP server.
		if svc != nil {
			svc.Hub().Shutdown()
		}

		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		if err := server.Shutdown(shutdownCtx); err != nil {
			logger.LogError(l, "dashboard", "server_shutdown_error", err, nil)
		}
		logger.LogInfo(l, "dashboard", "server_shutdown_complete", nil)
	}()
}

// authCookieName is the cookie that carries the dashboard session ID: Login
// sets it, Logout clears it, and the UI auth checks read it.
const (
	authCookieName = "hp_session"
)

// withUIAuth wraps next with the dashboard's cookie-session check.
//
// When no UI password is configured, every request is passed through.
// Otherwise, requests that requiresUISession marks as protected (API calls and
// the websocket) get 401 unless they carry a valid hp_session cookie. All
// other requests (the SPA shell and its static assets) are passed through so
// the frontend can render its own login view.
func (s *Service) withUIAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if s.uiPassword == "" || !requiresUISession(r) {
			next(w, r)
			return
		}

		cookie, err := r.Cookie(authCookieName)
		if err != nil || !s.isSessionValid(cookie.Value) {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		next(w, r)
	}
}

// requiresUISession reports whether r must present a valid session when a UI
// password is set. Protected requests are:
//   - every /api/ endpoint except login, logout and auth-status;
//   - the websocket endpoint, matched by its /ws path and by a websocket
//     Upgrade header.
//
// A dotted final path segment never exempts an API or websocket request;
// otherwise something like /api/events/export/x.y would skip auth. The path is
// cleaned first so the decision uses the same canonical form the mux routes on.
func requiresUISession(r *http.Request) bool {
	p := path.Clean("/" + r.URL.Path)
	switch {
	case p == "/api/login", p == "/api/logout", p == "/api/auth-status":
		return false
	case p == "/api", strings.HasPrefix(p, "/api/"):
		return true
	case p == "/ws":
		return true
	default:
		return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
	}
}

// withBearerAuth wraps next so it runs only for requests whose header is
// "Authorization: Bearer <token>". The scheme is matched case-insensitively,
// and exactly one space must separate the scheme from the token.
// An empty token disables the check and every request is passed through.
// All failures respond 401 Unauthorized.
func withBearerAuth(token string, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if token == "" {
			next.ServeHTTP(w, r)
			return
		}

		scheme, credential, ok := strings.Cut(r.Header.Get("Authorization"), " ")
		// Constant-time comparison so response timing does not reveal how much
		// of a guessed token matched.
		if !ok || !strings.EqualFold(scheme, "bearer") ||
			subtle.ConstantTimeCompare([]byte(credential), []byte(token)) != 1 {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// Login authenticates the dashboard UI from a JSON body {"password": "..."}.
// On success it creates a session, sets it in the hp_session cookie with a
// 24-hour expiry, and responds 200.
//
// Error responses:
//   - 405 for non-POST requests;
//   - 400 for an undecodable or oversized body;
//   - 401 for a wrong password;
//   - 500 if a session ID cannot be generated.
func (s *Service) Login(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// The body only carries a password. Cap it so unauthenticated clients
	// cannot make the decoder consume an arbitrarily large payload.
	const maxLoginBodyBytes = 4 << 10
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodyBytes)

	var req struct {
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request", http.StatusBadRequest)
		return
	}

	// Constant-time comparison so response timing does not reveal how much of
	// a guessed password matched.
	if subtle.ConstantTimeCompare([]byte(req.Password), []byte(s.uiPassword)) != 1 {
		http.Error(w, "Invalid password", http.StatusUnauthorized)
		return
	}

	sessionID, err := s.createSession()
	if err != nil {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     authCookieName,
		Value:    sessionID,
		Path:     "/",
		HttpOnly: true,
		Secure:   r.TLS != nil, // true only when this request was served directly over TLS
		SameSite: http.SameSiteLaxMode,
		Expires:  time.Now().Add(24 * time.Hour),
	})

	w.WriteHeader(http.StatusOK)
}

// Logout forgets the caller's session (if the hp_session cookie is present),
// instructs the browser to delete the cookie, and responds 200. It succeeds
// even when no session cookie was sent.
func (s *Service) Logout(w http.ResponseWriter, r *http.Request) {
	if c, err := r.Cookie(authCookieName); err == nil {
		s.sessionMu.Lock()
		delete(s.sessions, c.Value)
		s.sessionMu.Unlock()
	}

	// Overwrite the cookie with an empty value that is already expired.
	expired := &http.Cookie{
		Name:     authCookieName,
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		MaxAge:   -1,
		Expires:  time.Unix(0, 0),
	}
	http.SetCookie(w, expired)
	w.WriteHeader(http.StatusOK)
}

// AuthStatus writes a JSON object describing the caller's auth state:
// "auth_required" is true when a UI password is configured, and
// "authenticated" is true when the request carries a valid session cookie.
func (s *Service) AuthStatus(w http.ResponseWriter, r *http.Request) {
	cookie, err := r.Cookie(authCookieName)
	authenticated := err == nil && s.isSessionValid(cookie.Value)

	w.Header().Set("Content-Type", "application/json")
	payload := map[string]any{
		"auth_required": s.uiPassword != "",
		"authenticated": authenticated,
	}
	// Best effort: once the body has started there is no way to report an encode failure.
	_ = json.NewEncoder(w).Encode(payload)
}