internal/honeypot/http/middleware.go

package http

import (
	"context"
	"honeypot/internal/logger"
	"honeypot/internal/types"
	"honeypot/internal/utils"
	"net/http"
)

// portMiddleware injects the destination port into the request context.
func (h *httpHoneypot) portMiddleware(next http.Handler, port uint16) http.Handler {
	// The returned handler stores the request's destination port (a uint16)
	// in the context under dstPortKey, then delegates to next.
	//
	// By default the port is the listener port supplied by the caller. The
	// X-Forwarded-Proto header is honoured only when the peer address is
	// 127.0.0.1 or ::1 (presumably a local reverse proxy — confirm with the
	// deployment). "https" maps to 443 and "http" to 80. Any other value, and
	// any non-loopback peer, leaves the listener port unchanged.
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		effectivePort := port

		host, _ := utils.SplitAddr(r.RemoteAddr, h.logger)
		fromLoopback := host == "127.0.0.1" || host == "::1"
		if fromLoopback {
			switch r.Header.Get("X-Forwarded-Proto") {
			case "http":
				effectivePort = 80
			case "https":
				effectivePort = 443
			}
		}

		// Keep the stored value typed as uint16: readers type-assert on it.
		tagged := context.WithValue(r.Context(), dstPortKey, effectivePort)
		next.ServeHTTP(w, r.WithContext(tagged))
	})
}

// bodySizeMiddleware limits the request body size.
// bodySizeMiddleware enforces h.maxBodySize on request bodies for methods
// listed in MethodsWithBody. Requests using any other method pass through
// unchanged.
//
// Enforcement has two layers:
//   - A declared Content-Length above the limit is rejected up front with a
//     413 response, and next is never invoked.
//   - The body is wrapped in http.MaxBytesReader, so bodies whose length is
//     not declared (ContentLength == -1, e.g. chunked) still fail once more
//     than maxBodySize bytes are read by a downstream handler.
//
// The 413 is rendered with h.writeNginxError, the same helper used for the
// Basic-auth 401, instead of http.Error. http.Error emits a Go-specific
// "text/plain; charset=utf-8" body and an X-Content-Type-Options header,
// which would let a client fingerprint the server as Go rather than the
// nginx-style responses served elsewhere.
// NOTE(review): confirm writeNginxError renders a proper page for 413.
func (h *httpHoneypot) bodySizeMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !MethodsWithBody[r.Method] {
			next.ServeHTTP(w, r)
			return
		}

		// Cheap rejection based on the client-declared length.
		if r.ContentLength > h.maxBodySize {
			h.writeNginxError(w, http.StatusRequestEntityTooLarge)
			return
		}

		// Cap actual reads as well, since Content-Length may be absent.
		// net/http always sets a non-nil Body on server requests. The nil
		// guard only protects hand-built requests (e.g. in tests).
		if r.Body != nil {
			r.Body = http.MaxBytesReader(w, r.Body, h.maxBodySize)
		}
		next.ServeHTTP(w, r)
	})
}

// loggingMiddleware logs all requests.
func (h *httpHoneypot) loggingMiddleware(next http.Handler) http.Handler {
	// logThenServe hands every request to h.logRequest first, then forwards
	// the same, unmodified request to next.
	logThenServe := func(w http.ResponseWriter, r *http.Request) {
		h.logRequest(r)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(logThenServe)
}

// basicAuthMiddleware handles basic authentication.
func (h *httpHoneypot) basicAuthMiddleware(realm string, next http.Handler) http.Handler {
	// The returned handler answers every request with a 401 Basic challenge
	// for realm. next is never called.
	//
	// Before responding, it logs exactly one event and records metrics for it:
	//   - types.EventAuthAttempt when the request carries parseable Basic
	//     credentials. The username, password, auth_type ("basic") and realm
	//     are added to the event fields.
	//   - types.EventRequest otherwise.
	// The destination port comes from the uint16 stored under dstPortKey in
	// the request context, and is 0 when that value is missing.
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		username, password, hasCreds := r.BasicAuth()

		fields := h.buildRequestFields(r)
		h.addHeadersToFields(r, fields)

		event := types.LogEvent{Event: types.EventRequest}
		if hasCreds {
			// Submitted credentials are logged verbatim.
			fields["username"] = username
			fields["password"] = password
			fields["auth_type"] = "basic"
			fields["realm"] = realm
			event.Event = types.EventAuthAttempt
		}

		remoteHost, remotePort := h.getRemoteAddr(r)
		// A failed assertion yields the zero value, i.e. port 0.
		dstPort, _ := r.Context().Value(dstPortKey).(uint16)

		event.Type = HoneypotType
		event.RemoteAddr = remoteHost
		event.RemotePort = remotePort
		event.DstPort = dstPort
		event.Fields = fields

		logger.LogEvent(h.logger, event)
		h.recordHTTPMetrics(event)

		// NOTE(review): realm is embedded without quoted-string escaping.
		// This assumes realm comes from trusted configuration and contains
		// no '"' or '\'.
		challenge := `Basic realm="` + realm + `"`
		w.Header().Set("WWW-Authenticate", challenge)
		h.writeNginxError(w, http.StatusUnauthorized)
	})
}